1package com.android.hotspot2.est;
2
3import android.net.Network;
4import android.util.Base64;
5import android.util.Log;
6
7import com.android.hotspot2.OMADMAdapter;
8import com.android.hotspot2.asn1.Asn1Class;
9import com.android.hotspot2.asn1.Asn1Constructed;
10import com.android.hotspot2.asn1.Asn1Decoder;
11import com.android.hotspot2.asn1.Asn1ID;
12import com.android.hotspot2.asn1.Asn1Integer;
13import com.android.hotspot2.asn1.Asn1Object;
14import com.android.hotspot2.asn1.Asn1Oid;
15import com.android.hotspot2.asn1.OidMappings;
16import com.android.hotspot2.osu.HTTPHandler;
17import com.android.hotspot2.osu.OSUFlowManager;
18import com.android.hotspot2.osu.OSUSocketFactory;
19import com.android.hotspot2.osu.commands.GetCertData;
20import com.android.hotspot2.pps.HomeSP;
21import com.android.hotspot2.utils.HTTPMessage;
22import com.android.hotspot2.utils.HTTPResponse;
23import com.android.org.bouncycastle.asn1.ASN1Encodable;
24import com.android.org.bouncycastle.asn1.ASN1EncodableVector;
25import com.android.org.bouncycastle.asn1.ASN1Set;
26import com.android.org.bouncycastle.asn1.DERBitString;
27import com.android.org.bouncycastle.asn1.DEREncodableVector;
28import com.android.org.bouncycastle.asn1.DERIA5String;
29import com.android.org.bouncycastle.asn1.DERObjectIdentifier;
30import com.android.org.bouncycastle.asn1.DERPrintableString;
31import com.android.org.bouncycastle.asn1.DERSet;
32import com.android.org.bouncycastle.asn1.x509.Attribute;
33import com.android.org.bouncycastle.jce.PKCS10CertificationRequest;
34import com.android.org.bouncycastle.jce.spec.ECNamedCurveGenParameterSpec;
35
36import java.io.ByteArrayInputStream;
37import java.io.IOException;
38import java.net.URL;
39import java.nio.ByteBuffer;
40import java.nio.charset.StandardCharsets;
41import java.security.AlgorithmParameters;
42import java.security.GeneralSecurityException;
43import java.security.KeyPair;
44import java.security.KeyPairGenerator;
45import java.security.KeyStore;
46import java.security.PrivateKey;
47import java.security.cert.CertificateFactory;
48import java.security.cert.X509Certificate;
49import java.util.ArrayList;
50import java.util.Arrays;
51import java.util.Collection;
52import java.util.HashMap;
53import java.util.HashSet;
54import java.util.Iterator;
55import java.util.List;
56import java.util.Map;
57import java.util.Set;
58
59import javax.net.ssl.KeyManager;
60import javax.security.auth.x500.X500Principal;
61
62//import com.android.org.bouncycastle.jce.provider.BouncyCastleProvider;
63
64public class ESTHandler implements AutoCloseable {
65    private static final String TAG = "HS2EST";
66    private static final int MinRSAKeySize = 2048;
67
68    private static final String CACERT_PATH = "/cacerts";
69    private static final String CSR_PATH = "/csrattrs";
70    private static final String SIMPLE_ENROLL_PATH = "/simpleenroll";
71    private static final String SIMPLE_REENROLL_PATH = "/simplereenroll";
72
73    private final URL mURL;
74    private final String mUser;
75    private final byte[] mPassword;
76    private final OSUSocketFactory mSocketFactory;
77    private final OMADMAdapter mOMADMAdapter;
78
79    private final List<X509Certificate> mCACerts = new ArrayList<>();
80    private final List<X509Certificate> mClientCerts = new ArrayList<>();
81    private PrivateKey mClientKey;
82
83    public ESTHandler(GetCertData certData, Network network, OMADMAdapter omadmAdapter,
84                      KeyManager km, KeyStore ks, HomeSP homeSP, OSUFlowManager.FlowType flowType)
85            throws IOException, GeneralSecurityException {
86        mURL = new URL(certData.getServer());
87        mUser = certData.getUserName();
88        mPassword = certData.getPassword();
89        mSocketFactory = OSUSocketFactory.getSocketFactory(ks, homeSP, flowType,
90                network, mURL, km, true);
91        mOMADMAdapter = omadmAdapter;
92    }
93
94    @Override
95    public void close() throws IOException {
96    }
97
98    public List<X509Certificate> getCACerts() {
99        return mCACerts;
100    }
101
102    public List<X509Certificate> getClientCerts() {
103        return mClientCerts;
104    }
105
106    public PrivateKey getClientKey() {
107        return mClientKey;
108    }
109
110    private static String indent(int amount) {
111        char[] indent = new char[amount * 2];
112        Arrays.fill(indent, ' ');
113        return new String(indent);
114    }
115
116    public void execute(boolean reenroll) throws IOException, GeneralSecurityException {
117        URL caURL = new URL(mURL.getProtocol(), mURL.getHost(), mURL.getPort(),
118                mURL.getFile() + CACERT_PATH);
119
120        HTTPResponse response;
121        try (HTTPHandler httpHandler = new HTTPHandler(StandardCharsets.ISO_8859_1, mSocketFactory,
122                mUser, mPassword)) {
123            response = httpHandler.doGetHTTP(caURL);
124
125            if (!"application/pkcs7-mime".equals(response.getHeaders().
126                    get(HTTPMessage.ContentTypeHeader))) {
127                throw new IOException("Unexpected Content-Type: " +
128                        response.getHeaders().get(HTTPMessage.ContentTypeHeader));
129            }
130            ByteBuffer octetBuffer = response.getBinaryPayload();
131            Collection<Asn1Object> pkcs7Content1 = Asn1Decoder.decode(octetBuffer);
132            for (Asn1Object asn1Object : pkcs7Content1) {
133                Log.d(TAG, "---");
134                Log.d(TAG, asn1Object.toString());
135            }
136            Log.d(TAG, CACERT_PATH);
137
138            mCACerts.addAll(unpackPkcs7(octetBuffer));
139            for (X509Certificate certificate : mCACerts) {
140                Log.d(TAG, "CA-Cert: " + certificate.getSubjectX500Principal());
141            }
142
143            /*
144            byte[] octets = new byte[octetBuffer.remaining()];
145            octetBuffer.duplicate().get(octets);
146            for (byte b : octets) {
147                System.out.printf("%02x ", b & 0xff);
148            }
149            Log.d(TAG, );
150            */
151
152            /* + BC
153            try {
154                byte[] octets = new byte[octetBuffer.remaining()];
155                octetBuffer.duplicate().get(octets);
156                ASN1InputStream asnin = new ASN1InputStream(octets);
157                for (int n = 0; n < 100; n++) {
158                    ASN1Primitive object = asnin.readObject();
159                    if (object == null) {
160                        break;
161                    }
162                    parseObject(object, 0);
163                }
164            }
165            catch (Throwable t) {
166                t.printStackTrace();
167            }
168
169            Collection<Asn1Object> pkcs7Content = Asn1Decoder.decode(octetBuffer);
170            for (Asn1Object asn1Object : pkcs7Content) {
171                Log.d(TAG, asn1Object);
172            }
173
174            if (pkcs7Content.size() != 1) {
175                throw new IOException("Unexpected pkcs 7 container: " + pkcs7Content.size());
176            }
177
178            Asn1Constructed pkcs7Root = (Asn1Constructed) pkcs7Content.iterator().next();
179            Iterator<Asn1ID> certPath = Arrays.asList(Pkcs7CertPath).iterator();
180            Asn1Object certObject = pkcs7Root.findObject(certPath);
181            if (certObject == null || certPath.hasNext()) {
182                throw new IOException("Failed to find cert; returned object " + certObject +
183                        ", path " + (certPath.hasNext() ? "short" : "exhausted"));
184            }
185
186            ByteBuffer certOctets = certObject.getPayload();
187            if (certOctets == null) {
188                throw new IOException("No cert payload in: " + certObject);
189            }
190
191            byte[] certBytes = new byte[certOctets.remaining()];
192            certOctets.get(certBytes);
193
194            CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
195            Certificate cert = certFactory.generateCertificate(new ByteArrayInputStream(certBytes));
196            Log.d(TAG, "EST Cert: " + cert);
197            */
198
199            URL csrURL = new URL(mURL.getProtocol(), mURL.getHost(), mURL.getPort(),
200                    mURL.getFile() + CSR_PATH);
201            response = httpHandler.doGetHTTP(csrURL);
202
203            octetBuffer = response.getBinaryPayload();
204            byte[] csrData = buildCSR(octetBuffer, mOMADMAdapter, httpHandler);
205
206        /**/
207            Collection<Asn1Object> o = Asn1Decoder.decode(ByteBuffer.wrap(csrData));
208            Log.d(TAG, "CSR:");
209            Log.d(TAG, o.iterator().next().toString());
210            Log.d(TAG, "End CSR.");
211        /**/
212
213            URL enrollURL = new URL(mURL.getProtocol(), mURL.getHost(), mURL.getPort(),
214                    mURL.getFile() + (reenroll ? SIMPLE_REENROLL_PATH : SIMPLE_ENROLL_PATH));
215            String data = Base64.encodeToString(csrData, Base64.DEFAULT);
216            octetBuffer = httpHandler.exchangeBinary(enrollURL, data, "application/pkcs10");
217
218            Collection<Asn1Object> pkcs7Content2 = Asn1Decoder.decode(octetBuffer);
219            for (Asn1Object asn1Object : pkcs7Content2) {
220                Log.d(TAG, "---");
221                Log.d(TAG, asn1Object.toString());
222            }
223            mClientCerts.addAll(unpackPkcs7(octetBuffer));
224            for (X509Certificate cert : mClientCerts) {
225                Log.d(TAG, cert.toString());
226            }
227        }
228    }
229
230    private static final Asn1ID sSEQUENCE = new Asn1ID(Asn1Decoder.TAG_SEQ, Asn1Class.Universal);
231    private static final Asn1ID sCTXT0 = new Asn1ID(0, Asn1Class.Context);
232    private static final int PKCS7DataVersion = 1;
233    private static final int PKCS7SignedDataVersion = 3;
234
235    private static List<X509Certificate> unpackPkcs7(ByteBuffer pkcs7)
236            throws IOException, GeneralSecurityException {
237        Collection<Asn1Object> pkcs7Content = Asn1Decoder.decode(pkcs7);
238
239        if (pkcs7Content.size() != 1) {
240            throw new IOException("Unexpected pkcs 7 container: " + pkcs7Content.size());
241        }
242
243        Asn1Object data = pkcs7Content.iterator().next();
244        if (!data.isConstructed() || !data.matches(sSEQUENCE)) {
245            throw new IOException("Expected SEQ OF, got " + data.toSimpleString());
246        } else if (data.getChildren().size() != 2) {
247            throw new IOException("Expected content info to have two children, got " +
248                    data.getChildren().size());
249        }
250
251        Iterator<Asn1Object> children = data.getChildren().iterator();
252        Asn1Object contentType = children.next();
253        if (!contentType.equals(Asn1Oid.PKCS7SignedData)) {
254            throw new IOException("Content not PKCS7 signed data");
255        }
256        Asn1Object content = children.next();
257        if (!content.isConstructed() || !content.matches(sCTXT0)) {
258            throw new IOException("Expected [CONTEXT 0] with one child, got " +
259                    content.toSimpleString() + ", " + content.getChildren().size());
260        }
261
262        Asn1Object signedData = content.getChildren().iterator().next();
263        Map<Integer, Asn1Object> itemMap = new HashMap<>();
264        for (Asn1Object item : signedData.getChildren()) {
265            if (itemMap.put(item.getTag(), item) != null && item.getTag() != Asn1Decoder.TAG_SET) {
266                throw new IOException("Duplicate item in SignedData: " + item.toSimpleString());
267            }
268        }
269
270        Asn1Object versionObject = itemMap.get(Asn1Decoder.TAG_INTEGER);
271        if (versionObject == null || !(versionObject instanceof Asn1Integer)) {
272            throw new IOException("Bad or missing PKCS7 version: " + versionObject);
273        }
274        int pkcs7version = (int) ((Asn1Integer) versionObject).getValue();
275        Asn1Object innerContentInfo = itemMap.get(Asn1Decoder.TAG_SEQ);
276        if (innerContentInfo == null ||
277                !innerContentInfo.isConstructed() ||
278                !innerContentInfo.matches(sSEQUENCE) ||
279                innerContentInfo.getChildren().size() != 1) {
280            throw new IOException("Bad or missing PKCS7 contentInfo");
281        }
282        Asn1Object contentID = innerContentInfo.getChildren().iterator().next();
283        if (pkcs7version == PKCS7DataVersion && !contentID.equals(Asn1Oid.PKCS7Data) ||
284                pkcs7version == PKCS7SignedDataVersion && !contentID.equals(Asn1Oid.PKCS7SignedData)) {
285            throw new IOException("Inner PKCS7 content (" + contentID +
286                    ") not expected for version " + pkcs7version);
287        }
288        Asn1Object certWrapper = itemMap.get(0);
289        if (certWrapper == null || !certWrapper.isConstructed() || !certWrapper.matches(sCTXT0)) {
290            throw new IOException("Expected [CONTEXT 0], got: " + certWrapper);
291        }
292
293        List<X509Certificate> certList = new ArrayList<>(certWrapper.getChildren().size());
294        CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
295        for (Asn1Object certObject : certWrapper.getChildren()) {
296            ByteBuffer certOctets = ((Asn1Constructed) certObject).getEncoding();
297            if (certOctets == null) {
298                throw new IOException("No cert payload in: " + certObject);
299            }
300            byte[] certBytes = new byte[certOctets.remaining()];
301            certOctets.get(certBytes);
302
303            certList.add((X509Certificate) certFactory.
304                    generateCertificate(new ByteArrayInputStream(certBytes)));
305        }
306        return certList;
307    }
308
309    private byte[] buildCSR(ByteBuffer octetBuffer, OMADMAdapter omadmAdapter,
310                            HTTPHandler httpHandler) throws IOException, GeneralSecurityException {
311
312        //Security.addProvider(new BouncyCastleProvider());
313
314        Log.d(TAG, "/csrattrs:");
315        /*
316        byte[] octets = new byte[octetBuffer.remaining()];
317        octetBuffer.duplicate().get(octets);
318        for (byte b : octets) {
319            System.out.printf("%02x ", b & 0xff);
320        }
321        */
322        Collection<Asn1Object> csrs = Asn1Decoder.decode(octetBuffer);
323        for (Asn1Object asn1Object : csrs) {
324            Log.d(TAG, asn1Object.toString());
325        }
326
327        if (csrs.size() != 1) {
328            throw new IOException("Unexpected object count in CSR attributes response: " +
329                    csrs.size());
330        }
331        Asn1Object sequence = csrs.iterator().next();
332        if (sequence.getClass() != Asn1Constructed.class) {
333            throw new IOException("Unexpected CSR attribute container: " + sequence);
334        }
335
336        String keyAlgo = null;
337        Asn1Oid keyAlgoOID = null;
338        String sigAlgo = null;
339        String curveName = null;
340        Asn1Oid pubCrypto = null;
341        int keySize = -1;
342        Map<Asn1Oid, ASN1Encodable> idAttributes = new HashMap<>();
343
344        for (Asn1Object child : sequence.getChildren()) {
345            if (child.getTag() == Asn1Decoder.TAG_OID) {
346                Asn1Oid oid = (Asn1Oid) child;
347                OidMappings.SigEntry sigEntry = OidMappings.getSigEntry(oid);
348                if (sigEntry != null) {
349                    sigAlgo = sigEntry.getSigAlgo();
350                    keyAlgoOID = sigEntry.getKeyAlgo();
351                    keyAlgo = OidMappings.getJCEName(keyAlgoOID);
352                } else if (oid.equals(OidMappings.sPkcs9AtChallengePassword)) {
353                    byte[] tlsUnique = httpHandler.getTLSUnique();
354                    if (tlsUnique != null) {
355                        idAttributes.put(oid, new DERPrintableString(
356                                Base64.encodeToString(tlsUnique, Base64.DEFAULT)));
357                    } else {
358                        Log.w(TAG, "Cannot retrieve TLS unique channel binding");
359                    }
360                }
361            } else if (child.getTag() == Asn1Decoder.TAG_SEQ) {
362                Asn1Oid oid = null;
363                Set<Asn1Oid> oidValues = new HashSet<>();
364                List<Asn1Object> values = new ArrayList<>();
365
366                for (Asn1Object attributeSeq : child.getChildren()) {
367                    if (attributeSeq.getTag() == Asn1Decoder.TAG_OID) {
368                        oid = (Asn1Oid) attributeSeq;
369                    } else if (attributeSeq.getTag() == Asn1Decoder.TAG_SET) {
370                        for (Asn1Object value : attributeSeq.getChildren()) {
371                            if (value.getTag() == Asn1Decoder.TAG_OID) {
372                                oidValues.add((Asn1Oid) value);
373                            } else {
374                                values.add(value);
375                            }
376                        }
377                    }
378                }
379                if (oid == null) {
380                    throw new IOException("Invalid attribute, no OID");
381                }
382                if (oid.equals(OidMappings.sExtensionRequest)) {
383                    for (Asn1Oid subOid : oidValues) {
384                        if (OidMappings.isIDAttribute(subOid)) {
385                            if (subOid.equals(OidMappings.sMAC)) {
386                                idAttributes.put(subOid, new DERIA5String(omadmAdapter.getMAC()));
387                            } else if (subOid.equals(OidMappings.sIMEI)) {
388                                idAttributes.put(subOid, new DERIA5String(omadmAdapter.getImei()));
389                            } else if (subOid.equals(OidMappings.sMEID)) {
390                                idAttributes.put(subOid, new DERBitString(omadmAdapter.getMeid()));
391                            } else if (subOid.equals(OidMappings.sDevID)) {
392                                idAttributes.put(subOid,
393                                        new DERPrintableString(omadmAdapter.getDevID()));
394                            }
395                        }
396                    }
397                } else if (OidMappings.getCryptoID(oid) != null) {
398                    pubCrypto = oid;
399                    if (!values.isEmpty()) {
400                        for (Asn1Object value : values) {
401                            if (value.getTag() == Asn1Decoder.TAG_INTEGER) {
402                                keySize = (int) ((Asn1Integer) value).getValue();
403                            }
404                        }
405                    }
406                    if (oid.equals(OidMappings.sAlgo_EC)) {
407                        if (oidValues.isEmpty()) {
408                            throw new IOException("No ECC curve name provided");
409                        }
410                        for (Asn1Oid value : oidValues) {
411                            curveName = OidMappings.getJCEName(value);
412                            if (curveName != null) {
413                                break;
414                            }
415                        }
416                        if (curveName == null) {
417                            throw new IOException("Found no ECC curve for " + oidValues);
418                        }
419                    }
420                }
421            }
422        }
423
424        if (keyAlgoOID == null) {
425            throw new IOException("No public key algorithm specified");
426        }
427        if (pubCrypto != null && !pubCrypto.equals(keyAlgoOID)) {
428            throw new IOException("Mismatching key algorithms");
429        }
430
431        if (keyAlgoOID.equals(OidMappings.sAlgo_RSA)) {
432            if (keySize < MinRSAKeySize) {
433                if (keySize >= 0) {
434                    Log.i(TAG, "Upgrading suggested RSA key size from " +
435                            keySize + " to " + MinRSAKeySize);
436                }
437                keySize = MinRSAKeySize;
438            }
439        }
440
441        Log.d(TAG, String.format("pub key '%s', signature '%s', ECC curve '%s', id-atts %s",
442                keyAlgo, sigAlgo, curveName, idAttributes));
443
444        /*
445          Ruckus:
446            SEQUENCE:
447              OID=1.2.840.113549.1.1.11 (algo_id_sha256WithRSAEncryption)
448
449          RFC-7030:
450            SEQUENCE:
451              OID=1.2.840.113549.1.9.7 (challengePassword)
452              SEQUENCE:
453                OID=1.2.840.10045.2.1 (algo_id_ecPublicKey)
454                SET:
455                  OID=1.3.132.0.34 (secp384r1)
456              SEQUENCE:
457                OID=1.2.840.113549.1.9.14 (extensionRequest)
458                SET:
459                  OID=1.3.6.1.1.1.1.22 (mac-address)
460              OID=1.2.840.10045.4.3.3 (eccdaWithSHA384)
461
462              1L, 3L, 6L, 1L, 1L, 1L, 1L, 22
463         */
464
465        // ECC Does not appear to be supported currently
466        KeyPairGenerator kpg = KeyPairGenerator.getInstance(keyAlgo);
467        if (curveName != null) {
468            AlgorithmParameters algorithmParameters = AlgorithmParameters.getInstance(keyAlgo);
469            algorithmParameters.init(new ECNamedCurveGenParameterSpec(curveName));
470            kpg.initialize(algorithmParameters
471                    .getParameterSpec(ECNamedCurveGenParameterSpec.class));
472        } else {
473            kpg.initialize(keySize);
474        }
475        KeyPair kp = kpg.generateKeyPair();
476
477        X500Principal subject = new X500Principal("CN=Android, O=Google, C=US");
478
479        mClientKey = kp.getPrivate();
480
481        // !!! Map the idAttributes into an ASN1Set of values to pass to
482        // the PKCS10CertificationRequest - this code is using outdated BC classes and
483        // has *not* been tested.
484        ASN1Set attributes;
485        if (!idAttributes.isEmpty()) {
486            ASN1EncodableVector payload = new DEREncodableVector();
487            for (Map.Entry<Asn1Oid, ASN1Encodable> entry : idAttributes.entrySet()) {
488                DERObjectIdentifier type = new DERObjectIdentifier(entry.getKey().toOIDString());
489                ASN1Set values = new DERSet(entry.getValue());
490                Attribute attribute = new Attribute(type, values);
491                payload.add(attribute);
492            }
493            attributes = new DERSet(payload);
494        } else {
495            attributes = null;
496        }
497
498        return new PKCS10CertificationRequest(sigAlgo, subject, kp.getPublic(),
499                attributes, mClientKey).getEncoded();
500    }
501}
502