1import re
2
3import mock
4import six
5from six.moves import http_client
6import unittest2
7
8from apitools.base.py import credentials_lib
9from apitools.base.py import util
10
11
def CreateUriValidator(uri_regexp, content=''):
    """Build a fake HTTP request callable that validates metadata requests.

    Args:
      uri_regexp: Compiled regular expression the requested URI must match
          (only its `match` method and `pattern` attribute are used).
      content: Body returned together with an OK status when the URI matches.

    Returns:
      A callable `CheckUri(uri, headers=None)` that returns a
      `(response, message)` pair. `response.status` is `http_client.OK`
      and `message` is `content` when the URI matches; otherwise the status
      is `http_client.BAD_REQUEST` and `message` describes the expected
      pattern.
    """
    def CheckUri(uri, headers=None):
        """Validate a single request.

        Raises:
          ValueError: if the 'X-Google-Metadata-Request' header is absent,
              including when no headers are supplied at all.
        """
        # `headers` defaults to None, so guard before the membership test;
        # otherwise a header-less call fails with TypeError rather than the
        # intended ValueError.
        if headers is None or 'X-Google-Metadata-Request' not in headers:
            raise ValueError('Missing required header')
        if uri_regexp.match(uri):
            message = content
            status = http_client.OK
        else:
            message = 'Expected uri matching pattern %s' % uri_regexp.pattern
            status = http_client.BAD_REQUEST
        # Minimal stand-in response object exposing only a `status` attribute.
        return type('HttpResponse', (object,), {'status': status})(), message
    return CheckUri
24
25
class CredentialsLibTest(unittest2.TestCase):
    """Tests GceAssertionCredentials against a faked metadata endpoint."""

    def _GetServiceCreds(self, service_account_name=None, scopes=None):
        """Create GCE credentials under mocks and refresh them once.

        Patches `credentials_lib._GceMetadataRequest` with canned responses
        and `util.DetectGce` to return True, then asserts that `_refresh`
        returns None and that exactly three metadata requests were issued.

        Args:
          service_account_name: Optional account name; it is forwarded to
              the credentials constructor only when not None, and 'default'
              is used for the faked responses when it is falsy.
          scopes: Optional list of scopes passed to the constructor; the
              fake serves ['scope1'] when this is falsy.
        """
        account = service_account_name or 'default'
        creds_kwargs = (
            {} if service_account_name is None
            else {'service_account_name': service_account_name})
        token_suffix = '/service-accounts/%s/token' % account

        def FakeMetadataRequest(request_url):
            # Serve a canned body chosen by how the requested URL ends.
            if request_url.endswith('scopes'):
                return six.StringIO(''.join(scopes or ['scope1']))
            if request_url.endswith('service-accounts'):
                return six.StringIO(account)
            if request_url.endswith(token_suffix):
                return six.StringIO('{"access_token": "token"}')
            self.fail('Unexpected HTTP request to %s' % request_url)

        metadata_patch = mock.patch.object(
            credentials_lib, '_GceMetadataRequest',
            side_effect=FakeMetadataRequest, autospec=True)
        detect_patch = mock.patch.object(util, 'DetectGce', autospec=True)

        with metadata_patch as opener_mock:
            with detect_patch as mock_detect:
                mock_detect.return_value = True
                uri_pattern = re.compile(r'.*/%s/.*' % account)
                validator = CreateUriValidator(
                    uri_pattern, content='{"access_token": "token"}')
                credentials = credentials_lib.GceAssertionCredentials(
                    scopes, **creds_kwargs)
                refresh_result = credentials._refresh(validator)
                self.assertIsNone(refresh_result)
            self.assertEqual(3, opener_mock.call_count)

    def testGceServiceAccounts(self):
        """Covers the default account, explicit scopes, and a named account."""
        explicit_scopes = ['scope1']
        self._GetServiceCreds()
        self._GetServiceCreds(scopes=explicit_scopes)
        self._GetServiceCreds(scopes=explicit_scopes,
                              service_account_name='my_service_account')
65
66
class TestGetRunFlowFlags(unittest2.TestCase):
    """Checks `_GetRunFlowFlags` with and without a module-level FLAGS."""

    def setUp(self):
        # Remember the current FLAGS so tearDown can undo per-test overrides.
        self._flags_actual = credentials_lib.FLAGS

    def tearDown(self):
        # Put back whatever FLAGS object was in place before the test ran.
        credentials_lib.FLAGS = self._flags_actual

    def test_with_gflags(self):
        """Flag values given on the argument list show up in the result."""
        host, port = 'myhostname', '144169'

        class MockFlags(object):
            auth_host_name = host
            auth_host_port = port
            auth_local_webserver = False

        credentials_lib.FLAGS = MockFlags
        argv = [
            '--auth_host_name=%s' % host,
            '--auth_host_port=%s' % port,
            '--noauth_local_webserver',
        ]
        flags = credentials_lib._GetRunFlowFlags(argv)

        expectations = (
            ('auth_host_name', host),
            ('auth_host_port', port),
            ('logging_level', 'ERROR'),
            ('noauth_local_webserver', True),
        )
        for attr_name, wanted in expectations:
            self.assertEqual(getattr(flags, attr_name), wanted)

    def test_without_gflags(self):
        """With FLAGS unset and no arguments, the asserted defaults apply."""
        credentials_lib.FLAGS = None
        flags = credentials_lib._GetRunFlowFlags([])

        expectations = (
            ('auth_host_name', 'localhost'),
            ('auth_host_port', [8080, 8090]),
            ('logging_level', 'ERROR'),
            ('noauth_local_webserver', False),
        )
        for attr_name, wanted in expectations:
            self.assertEqual(getattr(flags, attr_name), wanted)
102