1#!/usr/bin/python2.4
2#
3# Copyright 2014 Google Inc. All rights reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18"""Oauth2client.file tests
19
20Unit tests for oauth2client.file
21"""
22
23__author__ = 'jcgregorio@google.com (Joe Gregorio)'
24
25import copy
26import datetime
27import json
28import os
29import pickle
30import stat
31import tempfile
32import unittest
33
34from .http_mock import HttpMockSequence
35import six
36
37from oauth2client import file
38from oauth2client import locked_file
39from oauth2client import multistore_file
40from oauth2client import util
41from oauth2client.client import AccessTokenCredentials
42from oauth2client.client import OAuth2Credentials
43from six.moves import http_client
44try:
45  # Python2
46  from future_builtins import oct
47except:
48  pass
49
50
51FILENAME = tempfile.mktemp('oauth2client_test.data')
52
53
54class OAuth2ClientFileTests(unittest.TestCase):
55
56  def tearDown(self):
57    try:
58      os.unlink(FILENAME)
59    except OSError:
60      pass
61
62  def setUp(self):
63    try:
64      os.unlink(FILENAME)
65    except OSError:
66      pass
67
68  def create_test_credentials(self, client_id='some_client_id',
69                              expiration=None):
70    access_token = 'foo'
71    client_secret = 'cOuDdkfjxxnv+'
72    refresh_token = '1/0/a.df219fjls0'
73    token_expiry = expiration or datetime.datetime.utcnow()
74    token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
75    user_agent = 'refresh_checker/1.0'
76
77    credentials = OAuth2Credentials(
78        access_token, client_id, client_secret,
79        refresh_token, token_expiry, token_uri,
80        user_agent)
81    return credentials
82
83  def test_non_existent_file_storage(self):
84    s = file.Storage(FILENAME)
85    credentials = s.get()
86    self.assertEquals(None, credentials)
87
88  def test_no_sym_link_credentials(self):
89    if hasattr(os, 'symlink'):
90      SYMFILENAME = FILENAME + '.sym'
91      os.symlink(FILENAME, SYMFILENAME)
92      s = file.Storage(SYMFILENAME)
93      try:
94        s.get()
95        self.fail('Should have raised an exception.')
96      except file.CredentialsFileSymbolicLinkError:
97        pass
98      finally:
99        os.unlink(SYMFILENAME)
100
101  def test_pickle_and_json_interop(self):
102    # Write a file with a pickled OAuth2Credentials.
103    credentials = self.create_test_credentials()
104
105    f = open(FILENAME, 'wb')
106    pickle.dump(credentials, f)
107    f.close()
108
109    # Storage should be not be able to read that object, as the capability to
110    # read and write credentials as pickled objects has been removed.
111    s = file.Storage(FILENAME)
112    read_credentials = s.get()
113    self.assertEquals(None, read_credentials)
114
115    # Now write it back out and confirm it has been rewritten as JSON
116    s.put(credentials)
117    with open(FILENAME) as f:
118      data = json.load(f)
119
120    self.assertEquals(data['access_token'], 'foo')
121    self.assertEquals(data['_class'], 'OAuth2Credentials')
122    self.assertEquals(data['_module'], OAuth2Credentials.__module__)
123
124  def test_token_refresh_store_expired(self):
125    expiration = datetime.datetime.utcnow() - datetime.timedelta(minutes=15)
126    credentials = self.create_test_credentials(expiration=expiration)
127
128    s = file.Storage(FILENAME)
129    s.put(credentials)
130    credentials = s.get()
131    new_cred = copy.copy(credentials)
132    new_cred.access_token = 'bar'
133    s.put(new_cred)
134
135    access_token = '1/3w'
136    token_response = {'access_token': access_token, 'expires_in': 3600}
137    http = HttpMockSequence([
138        ({'status': '200'}, json.dumps(token_response).encode('utf-8')),
139    ])
140
141    credentials._refresh(http.request)
142    self.assertEquals(credentials.access_token, access_token)
143
144  def test_token_refresh_store_expires_soon(self):
145    # Tests the case where an access token that is valid when it is read from
146    # the store expires before the original request succeeds.
147    expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
148    credentials = self.create_test_credentials(expiration=expiration)
149
150    s = file.Storage(FILENAME)
151    s.put(credentials)
152    credentials = s.get()
153    new_cred = copy.copy(credentials)
154    new_cred.access_token = 'bar'
155    s.put(new_cred)
156
157    access_token = '1/3w'
158    token_response = {'access_token': access_token, 'expires_in': 3600}
159    http = HttpMockSequence([
160        ({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
161        ({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
162        ({'status': str(http_client.OK)},
163         json.dumps(token_response).encode('utf-8')),
164        ({'status': str(http_client.OK)},
165         b'Valid response to original request')
166    ])
167
168    credentials.authorize(http)
169    http.request('https://example.com')
170    self.assertEqual(credentials.access_token, access_token)
171
172  def test_token_refresh_good_store(self):
173    expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
174    credentials = self.create_test_credentials(expiration=expiration)
175
176    s = file.Storage(FILENAME)
177    s.put(credentials)
178    credentials = s.get()
179    new_cred = copy.copy(credentials)
180    new_cred.access_token = 'bar'
181    s.put(new_cred)
182
183    credentials._refresh(lambda x: x)
184    self.assertEquals(credentials.access_token, 'bar')
185
186  def test_token_refresh_stream_body(self):
187    expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
188    credentials = self.create_test_credentials(expiration=expiration)
189
190    s = file.Storage(FILENAME)
191    s.put(credentials)
192    credentials = s.get()
193    new_cred = copy.copy(credentials)
194    new_cred.access_token = 'bar'
195    s.put(new_cred)
196
197    valid_access_token = '1/3w'
198    token_response = {'access_token': valid_access_token, 'expires_in': 3600}
199    http = HttpMockSequence([
200        ({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
201        ({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
202        ({'status': str(http_client.OK)},
203         json.dumps(token_response).encode('utf-8')),
204        ({'status': str(http_client.OK)}, 'echo_request_body')
205    ])
206
207    body = six.StringIO('streaming body')
208
209    credentials.authorize(http)
210    _, content = http.request('https://example.com', body=body)
211    self.assertEqual(content, 'streaming body')
212    self.assertEqual(credentials.access_token, valid_access_token)
213
214  def test_credentials_delete(self):
215    credentials = self.create_test_credentials()
216
217    s = file.Storage(FILENAME)
218    s.put(credentials)
219    credentials = s.get()
220    self.assertNotEquals(None, credentials)
221    s.delete()
222    credentials = s.get()
223    self.assertEquals(None, credentials)
224
225  def test_access_token_credentials(self):
226    access_token = 'foo'
227    user_agent = 'refresh_checker/1.0'
228
229    credentials = AccessTokenCredentials(access_token, user_agent)
230
231    s = file.Storage(FILENAME)
232    credentials = s.put(credentials)
233    credentials = s.get()
234
235    self.assertNotEquals(None, credentials)
236    self.assertEquals('foo', credentials.access_token)
237    mode = os.stat(FILENAME).st_mode
238
239    if os.name == 'posix':
240      self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
241
242  def test_read_only_file_fail_lock(self):
243    credentials = self.create_test_credentials()
244
245    open(FILENAME, 'a+b').close()
246    os.chmod(FILENAME, 0o400)
247
248    store = multistore_file.get_credential_storage(
249        FILENAME,
250        credentials.client_id,
251        credentials.user_agent,
252        ['some-scope', 'some-other-scope'])
253
254    store.put(credentials)
255    if os.name == 'posix':
256      self.assertTrue(store._multistore._read_only)
257    os.chmod(FILENAME, 0o600)
258
259  def test_multistore_no_symbolic_link_files(self):
260    if hasattr(os, 'symlink'):
261      SYMFILENAME = FILENAME + 'sym'
262      os.symlink(FILENAME, SYMFILENAME)
263      store = multistore_file.get_credential_storage(
264          SYMFILENAME,
265          'some_client_id',
266          'user-agent/1.0',
267          ['some-scope', 'some-other-scope'])
268      try:
269        store.get()
270        self.fail('Should have raised an exception.')
271      except locked_file.CredentialsFileSymbolicLinkError:
272        pass
273      finally:
274        os.unlink(SYMFILENAME)
275
276  def test_multistore_non_existent_file(self):
277    store = multistore_file.get_credential_storage(
278        FILENAME,
279        'some_client_id',
280        'user-agent/1.0',
281        ['some-scope', 'some-other-scope'])
282
283    credentials = store.get()
284    self.assertEquals(None, credentials)
285
286  def test_multistore_file(self):
287    credentials = self.create_test_credentials()
288
289    store = multistore_file.get_credential_storage(
290        FILENAME,
291        credentials.client_id,
292        credentials.user_agent,
293        ['some-scope', 'some-other-scope'])
294
295    store.put(credentials)
296    credentials = store.get()
297
298    self.assertNotEquals(None, credentials)
299    self.assertEquals('foo', credentials.access_token)
300
301    store.delete()
302    credentials = store.get()
303
304    self.assertEquals(None, credentials)
305
306    if os.name == 'posix':
307      self.assertEquals('0o600', oct(stat.S_IMODE(os.stat(FILENAME).st_mode)))
308
309  def test_multistore_file_custom_key(self):
310    credentials = self.create_test_credentials()
311
312    custom_key = {'myapp': 'testing', 'clientid': 'some client'}
313    store = multistore_file.get_credential_storage_custom_key(
314        FILENAME, custom_key)
315
316    store.put(credentials)
317    stored_credentials = store.get()
318
319    self.assertNotEquals(None, stored_credentials)
320    self.assertEqual(credentials.access_token, stored_credentials.access_token)
321
322    store.delete()
323    stored_credentials = store.get()
324
325    self.assertEquals(None, stored_credentials)
326
327  def test_multistore_file_custom_string_key(self):
328    credentials = self.create_test_credentials()
329
330    # store with string key
331    store = multistore_file.get_credential_storage_custom_string_key(
332        FILENAME, 'mykey')
333
334    store.put(credentials)
335    stored_credentials = store.get()
336
337    self.assertNotEquals(None, stored_credentials)
338    self.assertEqual(credentials.access_token, stored_credentials.access_token)
339
340    # try retrieving with a dictionary
341    store_dict = multistore_file.get_credential_storage_custom_string_key(
342        FILENAME, {'key': 'mykey'})
343    stored_credentials = store.get()
344    self.assertNotEquals(None, stored_credentials)
345    self.assertEqual(credentials.access_token, stored_credentials.access_token)
346
347    store.delete()
348    stored_credentials = store.get()
349
350    self.assertEquals(None, stored_credentials)
351
352  def test_multistore_file_backwards_compatibility(self):
353    credentials = self.create_test_credentials()
354    scopes = ['scope1', 'scope2']
355
356    # store the credentials using the legacy key method
357    store = multistore_file.get_credential_storage(
358        FILENAME, 'client_id', 'user_agent', scopes)
359    store.put(credentials)
360
361    # retrieve the credentials using a custom key that matches the legacy key
362    key = {'clientId': 'client_id', 'userAgent': 'user_agent',
363           'scope': util.scopes_to_string(scopes)}
364    store = multistore_file.get_credential_storage_custom_key(FILENAME, key)
365    stored_credentials = store.get()
366
367    self.assertEqual(credentials.access_token, stored_credentials.access_token)
368
369  def test_multistore_file_get_all_keys(self):
370    # start with no keys
371    keys = multistore_file.get_all_credential_keys(FILENAME)
372    self.assertEquals([], keys)
373
374    # store credentials
375    credentials = self.create_test_credentials(client_id='client1')
376    custom_key = {'myapp': 'testing', 'clientid': 'client1'}
377    store1 = multistore_file.get_credential_storage_custom_key(
378        FILENAME, custom_key)
379    store1.put(credentials)
380
381    keys = multistore_file.get_all_credential_keys(FILENAME)
382    self.assertEquals([custom_key], keys)
383
384    # store more credentials
385    credentials = self.create_test_credentials(client_id='client2')
386    string_key = 'string_key'
387    store2 = multistore_file.get_credential_storage_custom_string_key(
388        FILENAME, string_key)
389    store2.put(credentials)
390
391    keys = multistore_file.get_all_credential_keys(FILENAME)
392    self.assertEquals(2, len(keys))
393    self.assertTrue(custom_key in keys)
394    self.assertTrue({'key': string_key} in keys)
395
396    # back to no keys
397    store1.delete()
398    store2.delete()
399    keys = multistore_file.get_all_credential_keys(FILENAME)
400    self.assertEquals([], keys)
401
402
403if __name__ == '__main__':
404  unittest.main()
405