1"""Tests for google3.cloud.bigscience.apitools.base.py.batch."""
2
3import textwrap
4
5import mock
6from six.moves import http_client
7from six.moves.urllib import parse
8import unittest2
9
10from apitools.base.py import batch
11from apitools.base.py import exceptions
12from apitools.base.py import http_wrapper
13
14
class FakeCredentials(object):
    """Credentials stub that records how many times it was refreshed."""

    def __init__(self):
        # Tests read this counter to check how many refreshes happened.
        self.num_refreshes = 0

    def refresh(self, _):
        """Count one refresh; the single argument is ignored."""
        self.num_refreshes = self.num_refreshes + 1
22
23
class FakeHttp(object):
    """HTTP stub whose only state is a ``request`` object.

    ``request.credentials`` exists only when credentials were supplied.
    """

    class FakeRequest(object):
        """Bare request holder; ``credentials`` is set only when not None."""

        def __init__(self, credentials=None):
            if credentials is None:
                # Leave the attribute undefined rather than storing None.
                return
            self.credentials = credentials

    def __init__(self, credentials=None):
        self.request = FakeHttp.FakeRequest(credentials)
34
35
class FakeService(object):

    """A service for testing.

    Provides canned answers for the service methods the batch code under
    test is expected to call; every method ignores what it does not need.
    """

    def GetMethodConfig(self, _):
        """Return a fresh empty method config, ignoring the argument."""
        return dict()

    def GetUploadConfig(self, _):
        """Return a fresh empty upload config, ignoring the argument."""
        return dict()

    def PrepareHttpRequest(  # pylint: disable=unused-argument
            self, method_config, request, global_params, upload_config):
        """Return the request stashed in global_params['desired_request'].

        The remaining arguments exist only to match the caller's interface.

        Raises:
          KeyError: if global_params has no 'desired_request' entry.
        """
        prepared = global_params['desired_request']
        return prepared

    def ProcessHttpResponse(self, _, http_response):
        """Hand the HTTP response back unchanged."""
        return http_response
54
55
class BatchTest(unittest2.TestCase):
    """Tests for batch.BatchApiRequest and batch.BatchHttpRequest."""

    def assertUrlEqual(self, expected_url, provided_url):
        """Assert two URLs are equal, comparing query strings as parsed dicts.

        Each URL is split with urlsplit; the query part is compared via
        parse_qs so the order of distinct query parameters does not matter.
        """

        def parse_components(url):
            parsed = parse.urlsplit(url)
            query = parse.parse_qs(parsed.query)
            return parsed._replace(query=''), query

        expected_parse, expected_query = parse_components(expected_url)
        provided_parse, provided_query = parse_components(provided_url)

        self.assertEqual(expected_parse, provided_parse)
        self.assertEqual(expected_query, provided_query)

    def __ConfigureMock(self, mock_request, expected_request, response):
        """Install a checking side effect on a mocked MakeRequest.

        Every call asserts that the outgoing request's URL and HTTP method
        match ``expected_request``. If ``response`` is a list, successive
        calls return its elements in order; otherwise ``response`` is
        returned on every call.
        """

        if isinstance(response, list):
            # Copy so popping below does not mutate the caller's list.
            response = list(response)

        def CheckRequest(_, request, **unused_kwds):
            self.assertUrlEqual(expected_request.url, request.url)
            self.assertEqual(expected_request.http_method, request.http_method)
            if isinstance(response, list):
                return response.pop(0)
            else:
                return response

        mock_request.side_effect = CheckRequest

    def testRequestServiceUnavailable(self):
        """A non-retryable 503 in a batch part is surfaced as an HttpError."""
        mock_service = FakeService()

        desired_url = 'https://www.example.com'
        # 503 is deliberately absent from retryable_codes.
        batch_api_request = batch.BatchApiRequest(batch_url=desired_url,
                                                  retryable_codes=[])
        # The request to be added. The actual request sent will be somewhat
        # larger, as this is added to a batch.
        desired_request = http_wrapper.Request(desired_url, 'POST', {
            'content-type': 'multipart/mixed; boundary="None"',
            'content-length': 80,
        }, 'x' * 80)

        with mock.patch.object(http_wrapper, 'MakeRequest',
                               autospec=True) as mock_request:
            self.__ConfigureMock(
                mock_request,
                http_wrapper.Request(desired_url, 'POST', {
                    'content-type': 'multipart/mixed; boundary="None"',
                    'content-length': 419,
                }, 'x' * 419),
                http_wrapper.Response({
                    'status': '200',
                    'content-type': 'multipart/mixed; boundary="boundary"',
                }, textwrap.dedent("""\
                --boundary
                content-type: text/plain
                content-id: <id+0>

                HTTP/1.1 503 SERVICE UNAVAILABLE
                nope
                --boundary--"""), None))

            batch_api_request.Add(
                mock_service, 'unused', None,
                global_params={'desired_request': desired_request})

            api_request_responses = batch_api_request.Execute(
                FakeHttp(), sleep_between_polls=0)

            self.assertEqual(1, len(api_request_responses))

            # Make sure we didn't retry non-retryable code 503.
            self.assertEqual(1, mock_request.call_count)

            self.assertTrue(api_request_responses[0].is_error)
            self.assertIsNone(api_request_responses[0].response)
            self.assertIsInstance(api_request_responses[0].exception,
                                  exceptions.HttpError)

    def testSingleRequestInBatch(self):
        """One added request is sent in one batch call and parsed back."""
        mock_service = FakeService()

        desired_url = 'https://www.example.com'
        batch_api_request = batch.BatchApiRequest(batch_url=desired_url)
        # The request to be added. The actual request sent will be somewhat
        # larger, as this is added to a batch.
        desired_request = http_wrapper.Request(desired_url, 'POST', {
            'content-type': 'multipart/mixed; boundary="None"',
            'content-length': 80,
        }, 'x' * 80)

        with mock.patch.object(http_wrapper, 'MakeRequest',
                               autospec=True) as mock_request:
            self.__ConfigureMock(
                mock_request,
                http_wrapper.Request(desired_url, 'POST', {
                    'content-type': 'multipart/mixed; boundary="None"',
                    'content-length': 419,
                }, 'x' * 419),
                http_wrapper.Response({
                    'status': '200',
                    'content-type': 'multipart/mixed; boundary="boundary"',
                }, textwrap.dedent("""\
                --boundary
                content-type: text/plain
                content-id: <id+0>

                HTTP/1.1 200 OK
                content
                --boundary--"""), None))

            batch_api_request.Add(mock_service, 'unused', None, {
                'desired_request': desired_request,
            })

            api_request_responses = batch_api_request.Execute(FakeHttp())

            self.assertEqual(1, len(api_request_responses))
            self.assertEqual(1, mock_request.call_count)

            self.assertFalse(api_request_responses[0].is_error)

            response = api_request_responses[0].response
            self.assertEqual({'status': '200'}, response.info)
            self.assertEqual('content', response.content)
            self.assertEqual(desired_url, response.request_url)

    def testRefreshOnAuthFailure(self):
        """A 401 part leads to one credentials refresh and a second send."""
        mock_service = FakeService()

        desired_url = 'https://www.example.com'
        batch_api_request = batch.BatchApiRequest(batch_url=desired_url)
        # The request to be added. The actual request sent will be somewhat
        # larger, as this is added to a batch.
        desired_request = http_wrapper.Request(desired_url, 'POST', {
            'content-type': 'multipart/mixed; boundary="None"',
            'content-length': 80,
        }, 'x' * 80)

        with mock.patch.object(http_wrapper, 'MakeRequest',
                               autospec=True) as mock_request:
            # First batch reply: 401 for the part; second reply: 200.
            self.__ConfigureMock(
                mock_request,
                http_wrapper.Request(desired_url, 'POST', {
                    'content-type': 'multipart/mixed; boundary="None"',
                    'content-length': 419,
                }, 'x' * 419), [
                    http_wrapper.Response({
                        'status': '200',
                        'content-type': 'multipart/mixed; boundary="boundary"',
                    }, textwrap.dedent("""\
                    --boundary
                    content-type: text/plain
                    content-id: <id+0>

                    HTTP/1.1 401 UNAUTHORIZED
                    Invalid grant

                    --boundary--"""), None),
                    http_wrapper.Response({
                        'status': '200',
                        'content-type': 'multipart/mixed; boundary="boundary"',
                    }, textwrap.dedent("""\
                    --boundary
                    content-type: text/plain
                    content-id: <id+0>

                    HTTP/1.1 200 OK
                    content
                    --boundary--"""), None)
                ])

            batch_api_request.Add(mock_service, 'unused', None, {
                'desired_request': desired_request,
            })

            credentials = FakeCredentials()
            api_request_responses = batch_api_request.Execute(
                FakeHttp(credentials=credentials), sleep_between_polls=0)

            self.assertEqual(1, len(api_request_responses))
            self.assertEqual(2, mock_request.call_count)
            self.assertEqual(1, credentials.num_refreshes)

            self.assertFalse(api_request_responses[0].is_error)

            response = api_request_responses[0].response
            self.assertEqual({'status': '200'}, response.info)
            self.assertEqual('content', response.content)
            self.assertEqual(desired_url, response.request_url)

    def testNoAttempts(self):
        """With max_retries=0 the result has neither response nor error."""
        desired_url = 'https://www.example.com'
        batch_api_request = batch.BatchApiRequest(batch_url=desired_url)
        batch_api_request.Add(FakeService(), 'unused', None, {
            'desired_request': http_wrapper.Request(desired_url, 'POST', {
                'content-type': 'multipart/mixed; boundary="None"',
                'content-length': 80,
            }, 'x' * 80),
        })
        api_request_responses = batch_api_request.Execute(None, max_retries=0)
        self.assertEqual(1, len(api_request_responses))
        self.assertIsNone(api_request_responses[0].response)
        self.assertIsNone(api_request_responses[0].exception)

    def _DoTestConvertIdToHeader(self, test_id, expected_result):
        """Assert _ConvertIdToHeader(test_id) == expected_result % base_id."""
        batch_request = batch.BatchHttpRequest('https://www.example.com')
        self.assertEqual(
            expected_result % batch_request._BatchHttpRequest__base_id,
            batch_request._ConvertIdToHeader(test_id))

    def testConvertIdSimple(self):
        """A plain id is wrapped as <base_id+id>."""
        self._DoTestConvertIdToHeader('blah', '<%s+blah>')

    def testConvertIdThatNeedsEscaping(self):
        """A '~' in the id is expected to be percent-encoded as %7E."""
        self._DoTestConvertIdToHeader('~tilde1', '<%s+%%7Etilde1>')

    def _DoTestConvertHeaderToId(self, header, expected_id):
        """Assert _ConvertHeaderToId(header) == expected_id."""
        batch_request = batch.BatchHttpRequest('https://www.example.com')
        self.assertEqual(expected_id,
                         batch_request._ConvertHeaderToId(header))

    def testConvertHeaderToIdSimple(self):
        """The id is the text after the '+' inside the angle brackets."""
        self._DoTestConvertHeaderToId('<hello+blah>', 'blah')

    def testConvertHeaderToIdWithLotsOfPlus(self):
        """Only the text after the last '+' is taken as the id."""
        self._DoTestConvertHeaderToId('<a+++++plus>', 'plus')

    def _DoTestConvertInvalidHeaderToId(self, invalid_header):
        """Assert _ConvertHeaderToId(invalid_header) raises BatchError."""
        batch_request = batch.BatchHttpRequest('https://www.example.com')
        self.assertRaises(exceptions.BatchError,
                          batch_request._ConvertHeaderToId, invalid_header)

    def testHeaderWithoutAngleBrackets(self):
        """A header missing the surrounding <> is rejected."""
        self._DoTestConvertInvalidHeaderToId('1+1')

    def testHeaderWithoutPlus(self):
        """A header without a '+' separator is rejected."""
        self._DoTestConvertInvalidHeaderToId('<HEADER>')

    def testSerializeRequest(self):
        """A request serializes to request line, headers, blank line, body."""
        request = http_wrapper.Request(body='Hello World', headers={
            'content-type': 'protocol/version',
        })
        # The request has no URL, so nothing appears between 'GET' and
        # 'HTTP/1.1' (hence the double space).
        expected_serialized_request = '\n'.join([
            'GET  HTTP/1.1',
            'Content-Type: protocol/version',
            'MIME-Version: 1.0',
            'content-length: 11',
            'Host: ',
            '',
            'Hello World',
        ])
        batch_request = batch.BatchHttpRequest('https://www.example.com')
        self.assertEqual(expected_serialized_request,
                         batch_request._SerializeRequest(request))

    def testSerializeRequestPreservesHeaders(self):
        """Arbitrary extra headers appear in the serialized request."""
        # Now confirm that if an additional, arbitrary header is added
        # that it is successfully serialized to the request. Merely
        # check that it is included, because the order of the headers
        # in the request is arbitrary.
        request = http_wrapper.Request(body='Hello World', headers={
            'content-type': 'protocol/version',
            'key': 'value',
        })
        batch_request = batch.BatchHttpRequest('https://www.example.com')
        self.assertIn('key: value\n', batch_request._SerializeRequest(request))

    def testSerializeRequestNoBody(self):
        """A bodiless request has no content-length and ends with a blank."""
        request = http_wrapper.Request(body=None, headers={
            'content-type': 'protocol/version',
        })
        expected_serialized_request = '\n'.join([
            'GET  HTTP/1.1',
            'Content-Type: protocol/version',
            'MIME-Version: 1.0',
            'Host: ',
            '',
            '',
        ])
        batch_request = batch.BatchHttpRequest('https://www.example.com')
        self.assertEqual(expected_serialized_request,
                         batch_request._SerializeRequest(request))

    def testDeserializeRequest(self):
        """Parse a serialized payload with _DeserializeResponse.

        Despite the test name, the method exercised is _DeserializeResponse:
        headers become the response info, the text after the blank line
        becomes the content, and the batch URL becomes the request URL.
        """
        serialized_payload = '\n'.join([
            'GET  HTTP/1.1',
            'Content-Type: protocol/version',
            'MIME-Version: 1.0',
            'content-length: 11',
            'key: value',
            'Host: ',
            '',
            'Hello World',
        ])
        example_url = 'https://www.example.com'
        expected_response = http_wrapper.Response({
            'content-length': str(len('Hello World')),
            'Content-Type': 'protocol/version',
            'key': 'value',
            'MIME-Version': '1.0',
            'status': '',
            'Host': ''
        }, 'Hello World', example_url)

        batch_request = batch.BatchHttpRequest(example_url)
        self.assertEqual(
            expected_response,
            batch_request._DeserializeResponse(serialized_payload))

    def testNewId(self):
        """_NewId returns consecutive integers as strings, starting at '0'."""
        batch_request = batch.BatchHttpRequest('https://www.example.com')

        for i in range(100):
            self.assertEqual(str(i), batch_request._NewId())

    def testAdd(self):
        """Add registers one handler per request with no response/callback."""
        batch_request = batch.BatchHttpRequest('https://www.example.com')

        for x in range(100):
            batch_request.Add(http_wrapper.Request(body=str(x)))

        handlers = batch_request._BatchHttpRequest__request_response_handlers
        # Without this, the loop below would pass vacuously on an empty dict.
        self.assertEqual(100, len(handlers))

        for key, value in handlers.items():
            # Each body was chosen to equal the id its request should get.
            self.assertEqual(key, value.request.body)
            self.assertFalse(value.request.url)
            self.assertEqual('GET', value.request.http_method)
            self.assertIsNone(value.response)
            self.assertIsNone(value.handler)

    def testInternalExecuteWithFailedRequest(self):
        """_Execute raises HttpError when the batch call returns 300."""
        with mock.patch.object(http_wrapper, 'MakeRequest',
                               autospec=True) as mock_request:
            self.__ConfigureMock(
                mock_request,
                http_wrapper.Request('https://www.example.com', 'POST', {
                    'content-type': 'multipart/mixed; boundary="None"',
                    'content-length': 80,
                }, 'x' * 80),
                http_wrapper.Response({'status': '300'}, None, None))

            batch_request = batch.BatchHttpRequest('https://www.example.com')

            self.assertRaises(
                exceptions.HttpError, batch_request._Execute, None)

    def testInternalExecuteWithNonMultipartResponse(self):
        """_Execute raises BatchError when the reply is not multipart."""
        with mock.patch.object(http_wrapper, 'MakeRequest',
                               autospec=True) as mock_request:
            self.__ConfigureMock(
                mock_request,
                http_wrapper.Request('https://www.example.com', 'POST', {
                    'content-type': 'multipart/mixed; boundary="None"',
                    'content-length': 80,
                }, 'x' * 80),
                http_wrapper.Response({
                    'status': '200',
                    'content-type': 'blah/blah'
                }, '', None))

            batch_request = batch.BatchHttpRequest('https://www.example.com')

            self.assertRaises(
                exceptions.BatchError, batch_request._Execute, None)

    def testInternalExecute(self):
        """_Execute matches each multipart part to its request by content-id.

        The parts arrive in the opposite order to the request ids, so the
        assertions only pass if routing uses the content-id header.
        """
        with mock.patch.object(http_wrapper, 'MakeRequest',
                               autospec=True) as mock_request:
            self.__ConfigureMock(
                mock_request,
                http_wrapper.Request('https://www.example.com', 'POST', {
                    'content-type': 'multipart/mixed; boundary="None"',
                    'content-length': 583,
                }, 'x' * 583),
                http_wrapper.Response({
                    'status': '200',
                    'content-type': 'multipart/mixed; boundary="boundary"',
                }, textwrap.dedent("""\
                --boundary
                content-type: text/plain
                content-id: <id+2>

                HTTP/1.1 200 OK
                Second response

                --boundary
                content-type: text/plain
                content-id: <id+1>

                HTTP/1.1 401 UNAUTHORIZED
                First response

                --boundary--"""), None))

            test_requests = {
                '1': batch.RequestResponseAndHandler(
                    http_wrapper.Request(body='first'), None, None),
                '2': batch.RequestResponseAndHandler(
                    http_wrapper.Request(body='second'), None, None),
            }

            batch_request = batch.BatchHttpRequest('https://www.example.com')
            batch_request._BatchHttpRequest__request_response_handlers = (
                test_requests)

            batch_request._Execute(FakeHttp())

            test_responses = (
                batch_request._BatchHttpRequest__request_response_handlers)

            self.assertEqual(http_client.UNAUTHORIZED,
                             test_responses['1'].response.status_code)
            self.assertEqual(http_client.OK,
                             test_responses['2'].response.status_code)

            self.assertIn(
                'First response', test_responses['1'].response.content)
            self.assertIn(
                'Second response', test_responses['2'].response.content)

    def testPublicExecute(self):
        """Execute invokes per-request and global callbacks for each handler.

        _Execute is mocked out, so only the callback dispatch is tested.
        """
        # Records invocations so the test fails if the per-request callback
        # is never called (its assertions would otherwise never run).
        local_callback_calls = []

        def LocalCallback(response, exception):
            local_callback_calls.append((response, exception))
            self.assertEqual({'status': '418'}, response.info)
            self.assertEqual('Teapot', response.content)
            self.assertIsNone(response.request_url)
            self.assertIsInstance(exception, exceptions.HttpError)

        global_callback = mock.Mock()
        batch_request = batch.BatchHttpRequest(
            'https://www.example.com', global_callback)

        with mock.patch.object(batch.BatchHttpRequest, '_Execute',
                               autospec=True) as mock_execute:
            mock_execute.return_value = None

            test_requests = {
                '0': batch.RequestResponseAndHandler(
                    None,
                    http_wrapper.Response({'status': '200'}, 'Hello!', None),
                    None),
                '1': batch.RequestResponseAndHandler(
                    None,
                    http_wrapper.Response({'status': '418'}, 'Teapot', None),
                    LocalCallback),
            }

            batch_request._BatchHttpRequest__request_response_handlers = (
                test_requests)
            batch_request.Execute(None)

            # Only request '1' has a per-request callback.
            self.assertEqual(1, len(local_callback_calls))
            # Global callback was called once per handler.
            self.assertEqual(len(test_requests), global_callback.call_count)
512