1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License
15 */
16package com.android.voicemail.impl.transcribe.grpc;
17
18import android.content.Context;
19import android.content.pm.PackageInfo;
20import android.content.pm.PackageManager;
21import android.text.TextUtils;
22import com.android.dialer.common.Assert;
23import com.android.dialer.common.LogUtil;
24import com.android.voicemail.impl.transcribe.TranscriptionConfigProvider;
25import com.google.internal.communications.voicemailtranscription.v1.VoicemailTranscriptionServiceGrpc;
26import io.grpc.CallOptions;
27import io.grpc.Channel;
28import io.grpc.ClientCall;
29import io.grpc.ClientInterceptor;
30import io.grpc.ClientInterceptors;
31import io.grpc.ForwardingClientCall;
32import io.grpc.ManagedChannel;
33import io.grpc.ManagedChannelBuilder;
34import io.grpc.Metadata;
35import io.grpc.MethodDescriptor;
36import io.grpc.okhttp.OkHttpChannelBuilder;
37import java.security.MessageDigest;
38
39/**
40 * Factory for creating grpc clients that talk to the transcription server. This allows all clients
41 * to share the same channel, which is relatively expensive to create.
42 */
43public class TranscriptionClientFactory {
44  private static final String DIGEST_ALGORITHM_SHA1 = "SHA1";
45  private static final char[] HEX_UPPERCASE = {
46    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
47  };
48
49  private final TranscriptionConfigProvider configProvider;
50  private final ManagedChannel originalChannel;
51  private final String packageName;
52  private final String cert;
53
54  public TranscriptionClientFactory(Context context, TranscriptionConfigProvider configProvider) {
55    this(context, configProvider, getManagedChannel(configProvider));
56  }
57
58  public TranscriptionClientFactory(
59      Context context, TranscriptionConfigProvider configProvider, ManagedChannel managedChannel) {
60    this.configProvider = configProvider;
61    this.packageName = context.getPackageName();
62    this.cert = getCertificateFingerprint(context);
63    originalChannel = managedChannel;
64  }
65
66  public TranscriptionClient getClient() {
67    LogUtil.enterBlock("TranscriptionClientFactory.getClient");
68    Assert.checkState(!originalChannel.isShutdown());
69    Channel channel =
70        ClientInterceptors.intercept(
71            originalChannel,
72            new Interceptor(
73                packageName, cert, configProvider.getApiKey(), configProvider.getAuthToken()));
74    return new TranscriptionClient(VoicemailTranscriptionServiceGrpc.newBlockingStub(channel));
75  }
76
77  public void shutdown() {
78    LogUtil.enterBlock("TranscriptionClientFactory.shutdown");
79    originalChannel.shutdown();
80  }
81
82  private static ManagedChannel getManagedChannel(TranscriptionConfigProvider configProvider) {
83    ManagedChannelBuilder<OkHttpChannelBuilder> builder =
84        OkHttpChannelBuilder.forTarget(configProvider.getServerAddress());
85    // Only use plaintext for debugging
86    if (configProvider.shouldUsePlaintext()) {
87      // Just passing 'false' doesnt have the same effect as not setting this field
88      builder.usePlaintext(true);
89    }
90    return builder.build();
91  }
92
93  private static String getCertificateFingerprint(Context context) {
94    try {
95      PackageInfo packageInfo =
96          context
97              .getPackageManager()
98              .getPackageInfo(context.getPackageName(), PackageManager.GET_SIGNATURES);
99      if (packageInfo != null
100          && packageInfo.signatures != null
101          && packageInfo.signatures.length > 0) {
102        MessageDigest messageDigest = MessageDigest.getInstance(DIGEST_ALGORITHM_SHA1);
103        if (messageDigest == null) {
104          LogUtil.w(
105              "TranscriptionClientFactory.getCertificateFingerprint", "error getting digest.");
106          return null;
107        }
108        byte[] bytes = messageDigest.digest(packageInfo.signatures[0].toByteArray());
109        if (bytes == null) {
110          LogUtil.w(
111              "TranscriptionClientFactory.getCertificateFingerprint", "empty message digest.");
112          return null;
113        }
114
115        int length = bytes.length;
116        StringBuilder out = new StringBuilder(length * 2);
117        for (int i = 0; i < length; i++) {
118          out.append(HEX_UPPERCASE[(bytes[i] & 0xf0) >>> 4]);
119          out.append(HEX_UPPERCASE[bytes[i] & 0x0f]);
120        }
121        return out.toString();
122      } else {
123        LogUtil.w(
124            "TranscriptionClientFactory.getCertificateFingerprint",
125            "failed to get package signature.");
126      }
127    } catch (Exception e) {
128      LogUtil.e(
129          "TranscriptionClientFactory.getCertificateFingerprint",
130          "error getting certificate fingerprint.",
131          e);
132    }
133
134    return null;
135  }
136
137  private static final class Interceptor implements ClientInterceptor {
138    private final String packageName;
139    private final String cert;
140    private final String apiKey;
141    private final String authToken;
142
143    private static final Metadata.Key<String> API_KEY_HEADER =
144        Metadata.Key.of("X-Goog-Api-Key", Metadata.ASCII_STRING_MARSHALLER);
145    private static final Metadata.Key<String> ANDROID_PACKAGE_HEADER =
146        Metadata.Key.of("X-Android-Package", Metadata.ASCII_STRING_MARSHALLER);
147    private static final Metadata.Key<String> ANDROID_CERT_HEADER =
148        Metadata.Key.of("X-Android-Cert", Metadata.ASCII_STRING_MARSHALLER);
149    private static final Metadata.Key<String> AUTHORIZATION_HEADER =
150        Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER);
151
152    public Interceptor(String packageName, String cert, String apiKey, String authToken) {
153      this.packageName = packageName;
154      this.cert = cert;
155      this.apiKey = apiKey;
156      this.authToken = authToken;
157    }
158
159    @Override
160    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
161        MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
162      LogUtil.enterBlock(
163          "TranscriptionClientFactory.interceptCall, intercepted " + method.getFullMethodName());
164      ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
165
166      call =
167          new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(call) {
168            @Override
169            public void start(Listener<RespT> responseListener, Metadata headers) {
170              if (!TextUtils.isEmpty(packageName)) {
171                LogUtil.i(
172                    "TranscriptionClientFactory.interceptCall",
173                    "attaching package name: " + packageName);
174                headers.put(ANDROID_PACKAGE_HEADER, packageName);
175              }
176              if (!TextUtils.isEmpty(cert)) {
177                LogUtil.i("TranscriptionClientFactory.interceptCall", "attaching android cert");
178                headers.put(ANDROID_CERT_HEADER, cert);
179              }
180              if (!TextUtils.isEmpty(apiKey)) {
181                LogUtil.i("TranscriptionClientFactory.interceptCall", "attaching API Key");
182                headers.put(API_KEY_HEADER, apiKey);
183              }
184              if (!TextUtils.isEmpty(authToken)) {
185                LogUtil.i("TranscriptionClientFactory.interceptCall", "attaching auth token");
186                headers.put(AUTHORIZATION_HEADER, "Bearer " + authToken);
187              }
188              super.start(responseListener, headers);
189            }
190          };
191      return call;
192    }
193  }
194}
195