1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "google_apis/cup/client_update_protocol.h"
6
7#include "base/base64.h"
8#include "base/logging.h"
9#include "base/memory/scoped_ptr.h"
10#include "base/sha1.h"
11#include "base/strings/string_util.h"
12#include "base/strings/stringprintf.h"
13#include "crypto/hmac.h"
14#include "crypto/random.h"
15
16namespace {
17
18base::StringPiece ByteVectorToSP(const std::vector<uint8>& vec) {
19  if (vec.empty())
20    return base::StringPiece();
21
22  return base::StringPiece(reinterpret_cast<const char*>(&vec[0]), vec.size());
23}
24
25// This class needs to implement the same hashing and signing functions as the
26// Google Update server; for now, this is SHA-1 and HMAC-SHA1, but this may
27// change to SHA-256 in the near future.  For this reason, all primitives are
28// wrapped.  The name "SymSign" is used to mirror the CUP specification.
29size_t HashDigestSize() {
30  return base::kSHA1Length;
31}
32
33std::vector<uint8> Hash(const std::vector<uint8>& data) {
34  std::vector<uint8> result(HashDigestSize());
35  base::SHA1HashBytes(data.empty() ? NULL : &data[0],
36                      data.size(),
37                      &result[0]);
38  return result;
39}
40
41std::vector<uint8> Hash(const base::StringPiece& sdata) {
42  std::vector<uint8> result(HashDigestSize());
43  base::SHA1HashBytes(sdata.empty() ?
44                          NULL :
45                          reinterpret_cast<const unsigned char*>(sdata.data()),
46                      sdata.length(),
47                      &result[0]);
48  return result;
49}
50
51std::vector<uint8> SymConcat(uint8 id,
52                             const std::vector<uint8>* h1,
53                             const std::vector<uint8>* h2,
54                             const std::vector<uint8>* h3) {
55  std::vector<uint8> result;
56  result.push_back(id);
57  const std::vector<uint8>* args[] = { h1, h2, h3 };
58  for (size_t i = 0; i != arraysize(args); ++i) {
59    if (args[i]) {
60      DCHECK_EQ(args[i]->size(), HashDigestSize());
61      result.insert(result.end(), args[i]->begin(), args[i]->end());
62    }
63  }
64
65  return result;
66}
67
68std::vector<uint8> SymSign(const std::vector<uint8>& key,
69                           const std::vector<uint8>& hashes) {
70  DCHECK(!key.empty());
71  DCHECK(!hashes.empty());
72
73  crypto::HMAC hmac(crypto::HMAC::SHA1);
74  if (!hmac.Init(&key[0], key.size()))
75    return std::vector<uint8>();
76
77  std::vector<uint8> result(hmac.DigestLength());
78  if (!hmac.Sign(ByteVectorToSP(hashes), &result[0], result.size()))
79    return std::vector<uint8>();
80
81  return result;
82}
83
84bool SymSignVerify(const std::vector<uint8>& key,
85                   const std::vector<uint8>& hashes,
86                   const std::vector<uint8>& server_proof) {
87  DCHECK(!key.empty());
88  DCHECK(!hashes.empty());
89  DCHECK(!server_proof.empty());
90
91  crypto::HMAC hmac(crypto::HMAC::SHA1);
92  if (!hmac.Init(&key[0], key.size()))
93    return false;
94
95  return hmac.Verify(ByteVectorToSP(hashes), ByteVectorToSP(server_proof));
96}
97
98// RsaPad() is implemented as described in the CUP spec.  It is NOT a general
99// purpose padding algorithm.
100std::vector<uint8> RsaPad(size_t rsa_key_size,
101                          const std::vector<uint8>& entropy) {
102  DCHECK_GE(rsa_key_size, HashDigestSize());
103
104  // The result gets padded with zeros if the result size is greater than
105  // the size of the buffer provided by the caller.
106  std::vector<uint8> result(entropy);
107  result.resize(rsa_key_size - HashDigestSize());
108
109  // For use with RSA, the input needs to be smaller than the RSA modulus,
110  // which has always the msb set.
111  result[0] &= 127;  // Reset msb
112  result[0] |= 64;   // Set second highest bit.
113
114  std::vector<uint8> digest = Hash(result);
115  result.insert(result.end(), digest.begin(), digest.end());
116  DCHECK_EQ(result.size(), rsa_key_size);
117  return result;
118}
119
120// CUP passes the versioned secret in the query portion of the URL for the
121// update check service -- and that means that a URL-safe variant of Base64 is
122// needed.  Call the standard Base64 encoder/decoder and then apply fixups.
123std::string UrlSafeB64Encode(const std::vector<uint8>& data) {
124  std::string result;
125  base::Base64Encode(ByteVectorToSP(data), &result);
126
127  // Do an tr|+/|-_| on the output, and strip any '=' padding.
128  for (std::string::iterator it = result.begin(); it != result.end(); ++it) {
129    switch (*it) {
130      case '+':
131        *it = '-';
132        break;
133      case '/':
134        *it = '_';
135        break;
136      default:
137        break;
138    }
139  }
140  base::TrimString(result, "=", &result);
141
142  return result;
143}
144
145std::vector<uint8> UrlSafeB64Decode(const base::StringPiece& input) {
146  std::string unsafe(input.begin(), input.end());
147  for (std::string::iterator it = unsafe.begin(); it != unsafe.end(); ++it) {
148    switch (*it) {
149      case '-':
150        *it = '+';
151        break;
152      case '_':
153        *it = '/';
154        break;
155      default:
156        break;
157    }
158  }
159  if (unsafe.length() % 4)
160    unsafe.append(4 - (unsafe.length() % 4), '=');
161
162  std::string decoded;
163  if (!base::Base64Decode(unsafe, &decoded))
164    return std::vector<uint8>();
165
166  return std::vector<uint8>(decoded.begin(), decoded.end());
167}
168
169}  // end namespace
170
171ClientUpdateProtocol::ClientUpdateProtocol(int key_version)
172    : pub_key_version_(key_version) {
173}
174
175scoped_ptr<ClientUpdateProtocol> ClientUpdateProtocol::Create(
176    int key_version,
177    const base::StringPiece& public_key) {
178  DCHECK_GT(key_version, 0);
179  DCHECK(!public_key.empty());
180
181  scoped_ptr<ClientUpdateProtocol> result(
182      new ClientUpdateProtocol(key_version));
183  if (!result)
184    return scoped_ptr<ClientUpdateProtocol>();
185
186  if (!result->LoadPublicKey(public_key))
187    return scoped_ptr<ClientUpdateProtocol>();
188
189  if (!result->BuildRandomSharedKey())
190    return scoped_ptr<ClientUpdateProtocol>();
191
192  return result.Pass();
193}
194
195std::string ClientUpdateProtocol::GetVersionedSecret() const {
196  return base::StringPrintf("%d:%s",
197                            pub_key_version_,
198                            UrlSafeB64Encode(encrypted_key_source_).c_str());
199}
200
201bool ClientUpdateProtocol::SignRequest(const base::StringPiece& url,
202                                       const base::StringPiece& request_body,
203                                       std::string* client_proof) {
204  DCHECK(!encrypted_key_source_.empty());
205  DCHECK(!url.empty());
206  DCHECK(!request_body.empty());
207  DCHECK(client_proof);
208
209  // Compute the challenge hash:
210  //   hw = HASH(HASH(v|w)|HASH(request_url)|HASH(body)).
211  // Keep the challenge hash for later to validate the server's response.
212  std::vector<uint8> internal_hashes;
213
214  std::vector<uint8> h;
215  h = Hash(GetVersionedSecret());
216  internal_hashes.insert(internal_hashes.end(), h.begin(), h.end());
217  h = Hash(url);
218  internal_hashes.insert(internal_hashes.end(), h.begin(), h.end());
219  h = Hash(request_body);
220  internal_hashes.insert(internal_hashes.end(), h.begin(), h.end());
221  DCHECK_EQ(internal_hashes.size(), 3 * HashDigestSize());
222
223  client_challenge_hash_ = Hash(internal_hashes);
224
225  // Sign the challenge hash (hw) using the shared key (sk) to produce the
226  // client proof (cp).
227  std::vector<uint8> raw_client_proof =
228      SymSign(shared_key_, SymConcat(3, &client_challenge_hash_, NULL, NULL));
229  if (raw_client_proof.empty()) {
230    client_challenge_hash_.clear();
231    return false;
232  }
233
234  *client_proof = UrlSafeB64Encode(raw_client_proof);
235  return true;
236}
237
238bool ClientUpdateProtocol::ValidateResponse(
239    const base::StringPiece& response_body,
240    const base::StringPiece& server_cookie,
241    const base::StringPiece& server_proof) {
242  DCHECK(!client_challenge_hash_.empty());
243
244  if (response_body.empty() || server_cookie.empty() || server_proof.empty())
245    return false;
246
247  // Decode the server proof from URL-safe Base64 to a binary HMAC for the
248  // response.
249  std::vector<uint8> sp_decoded = UrlSafeB64Decode(server_proof);
250  if (sp_decoded.empty())
251    return false;
252
253  // If the request was received by the server, the server will use its
254  // private key to decrypt |w_|, yielding the original contents of |r_|.
255  // The server can then recreate |sk_|, compute |hw_|, and SymSign(3|hw)
256  // to ensure that the cp matches the contents.  It will then use |sk_|
257  // to sign its response, producing the server proof |sp|.
258  std::vector<uint8> hm = Hash(response_body);
259  std::vector<uint8> hc = Hash(server_cookie);
260  return SymSignVerify(shared_key_,
261                       SymConcat(1, &client_challenge_hash_, &hm, &hc),
262                       sp_decoded);
263}
264
265bool ClientUpdateProtocol::BuildRandomSharedKey() {
266  DCHECK_GE(PublicKeyLength(), HashDigestSize());
267
268  // Start by generating some random bytes that are suitable to be encrypted;
269  // this will be the source of the shared HMAC key that client and server use.
270  // (CUP specification calls this "r".)
271  std::vector<uint8> key_source;
272  std::vector<uint8> entropy(PublicKeyLength() - HashDigestSize());
273  crypto::RandBytes(&entropy[0], entropy.size());
274  key_source = RsaPad(PublicKeyLength(), entropy);
275
276  return DeriveSharedKey(key_source);
277}
278
279bool ClientUpdateProtocol::SetSharedKeyForTesting(
280  const base::StringPiece& key_source) {
281  DCHECK_EQ(key_source.length(), PublicKeyLength());
282
283  return DeriveSharedKey(std::vector<uint8>(key_source.begin(),
284                                            key_source.end()));
285}
286
287bool ClientUpdateProtocol::DeriveSharedKey(const std::vector<uint8>& source) {
288  DCHECK(!source.empty());
289  DCHECK_GE(source.size(), HashDigestSize());
290  DCHECK_EQ(source.size(), PublicKeyLength());
291
292  // Hash the key source (r) to generate a new shared HMAC key (sk').
293  shared_key_ = Hash(source);
294
295  // Encrypt the key source (r) using the public key (pk[v]) to generate the
296  // encrypted key source (w).
297  if (!EncryptKeySource(source))
298    return false;
299  if (encrypted_key_source_.size() != PublicKeyLength())
300    return false;
301
302  return true;
303}
304