1# -*- coding: utf-8 -*-
2# Copyright 2014 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Wrapper for use in daisy-chained copies."""
16
17from collections import deque
18import os
19import threading
20import time
21
22from gslib.cloud_api import BadRequestException
23from gslib.cloud_api import CloudApi
24from gslib.util import CreateLock
25from gslib.util import TRANSFER_BUFFER_SIZE
26
27
# Number of bytes requested from the source per download call (100 MiB).
# This is not the amount held in memory at once; the in-memory limit is
# DaisyChainWrapper.max_buffer_size. It is also the most data that may be
# downloaded needlessly when the resumable upload is interrupted.
_DEFAULT_DOWNLOAD_CHUNK_SIZE = 100 * 1024 * 1024
33
34
class BufferWrapper(object):
  """Wraps the download file pointer to use our in-memory buffer."""

  def __init__(self, daisy_chain_wrapper):
    """Provides a buffered write interface for a file download.

    Args:
      daisy_chain_wrapper: DaisyChainWrapper instance to use for buffer and
                           locking. Its lock, buffer, bytes_buffered and
                           max_buffer_size attributes are used by write().
    """
    self.daisy_chain_wrapper = daisy_chain_wrapper

  def write(self, data):  # pylint: disable=invalid-name
    """Waits for space in the buffer, then writes data to the buffer.

    The space check only requires bytes_buffered to be below max_buffer_size
    before appending, so the buffer may exceed the limit by up to one write.

    Args:
      data: Chunk of downloaded data to append. Empty data is ignored.
    """
    data_len = len(data)
    if not data_len:
      # Nothing would be appended, so do not wait for buffer space.
      return

    wrapper = self.daisy_chain_wrapper
    while True:
      # Check for space and append under a single lock acquisition so the
      # check cannot go stale before the data is stored.
      with wrapper.lock:
        if wrapper.bytes_buffered < wrapper.max_buffer_size:
          wrapper.buffer.append(data)
          wrapper.bytes_buffered += data_len
          return
      # Buffer was full, yield thread priority so the upload can pull from it.
      time.sleep(0)
61
62
class DaisyChainWrapper(object):
  """Wrapper class for daisy-chaining a cloud download to an upload.

  This class instantiates a BufferWrapper object to buffer the download into
  memory, consuming a maximum of max_buffer_size. It implements intelligent
  behavior around read and seek that allow for all of the operations necessary
  to copy a file.

  This class is coupled with the XML and JSON implementations in that it
  expects that small buffers (maximum of TRANSFER_BUFFER_SIZE) in size will be
  used.
  """

  def __init__(self, src_url, src_obj_size, gsutil_api, progress_callback=None,
               download_chunk_size=_DEFAULT_DOWNLOAD_CHUNK_SIZE):
    """Initializes the daisy chain wrapper and starts the download thread.

    Args:
      src_url: Source CloudUrl to copy from.
      src_obj_size: Size of source object.
      gsutil_api: gsutil Cloud API to use for the copy.
      progress_callback: Optional callback function for progress notifications
          for the download thread. Receives calls with arguments
          (bytes_transferred, total_size).
      download_chunk_size: Integer number of bytes to download per
          GetObjectMedia request. This is the upper bound of bytes that may be
          unnecessarily downloaded if there is a break in the resumable upload.

    """
    # Current read position for the upload file pointer.
    self.position = 0
    self.buffer = deque()

    self.bytes_buffered = 0
    # Maximum amount of bytes in memory at a time.
    self.max_buffer_size = 1024 * 1024  # 1 MiB

    self._download_chunk_size = download_chunk_size

    # We save one buffer's worth of data as a special case for boto,
    # which seeks back one buffer and rereads to compute hashes. This is
    # unnecessary because we can just compare cloud hash digests at the end,
    # but it allows this to work without modifying boto.
    self.last_position = 0
    self.last_data = None

    # Protects buffer, position, bytes_buffered, last_position, and last_data.
    self.lock = CreateLock()

    # Protects download_exception.
    self.download_exception_lock = CreateLock()

    self.src_obj_size = src_obj_size
    self.src_url = src_url

    # This is safe to use in both the upload and download threads because the
    # download thread calls only GetObjectMedia, which creates a new HTTP
    # connection independent of gsutil_api. Thus, it will not share an HTTP
    # connection with the upload.
    self.gsutil_api = gsutil_api

    # If self.download_thread dies due to an exception, it is saved here so
    # that it can also be raised in the upload thread.
    self.download_exception = None
    self.download_thread = None
    self.progress_callback = progress_callback
    # Set by seek() to ask the download thread to exit between chunks.
    self.stop_download = threading.Event()
    self.StartDownloadThread(progress_callback=self.progress_callback)

  def StartDownloadThread(self, start_byte=0, progress_callback=None):
    """Starts the download thread for the source object (from start_byte).

    Args:
      start_byte: Byte offset in the source object at which to begin.
      progress_callback: Optional callback function for progress
          notifications. Receives calls with arguments
          (bytes_transferred, total_size).
    """

    def PerformDownload(start_byte, progress_callback):
      """Downloads the source object in chunks.

      This function checks the stop_download event and exits early if it is set.
      It should be set when there is an error during the daisy-chain upload,
      then this function can be called again with the upload's current position
      as start_byte.

      Args:
        start_byte: Byte from which to begin the download.
        progress_callback: Optional callback function for progress
            notifications. Receives calls with arguments
            (bytes_transferred, total_size).
      """
      # TODO: Support resumable downloads. This would require the BufferWrapper
      # object to support seek() and tell() which requires coordination with
      # the upload.
      try:
        # Fetch full-size chunks while more than one chunk remains.
        while start_byte + self._download_chunk_size < self.src_obj_size:
          self.gsutil_api.GetObjectMedia(
              self.src_url.bucket_name, self.src_url.object_name,
              BufferWrapper(self), start_byte=start_byte,
              end_byte=start_byte + self._download_chunk_size - 1,
              generation=self.src_url.generation, object_size=self.src_obj_size,
              download_strategy=CloudApi.DownloadStrategy.ONE_SHOT,
              provider=self.src_url.scheme, progress_callback=progress_callback)
          if self.stop_download.is_set():
            # Download thread needs to be restarted, so exit.
            self.stop_download.clear()
            return
          start_byte += self._download_chunk_size
        # Final request has no end_byte: fetch everything that remains.
        self.gsutil_api.GetObjectMedia(
            self.src_url.bucket_name, self.src_url.object_name,
            BufferWrapper(self), start_byte=start_byte,
            generation=self.src_url.generation, object_size=self.src_obj_size,
            download_strategy=CloudApi.DownloadStrategy.ONE_SHOT,
            provider=self.src_url.scheme, progress_callback=progress_callback)
      # We catch all exceptions here because we want to store them.
      except Exception as e:  # pylint: disable=broad-except
        # Save the exception so that it can be seen in the upload thread.
        with self.download_exception_lock:
          self.download_exception = e
          raise

    # TODO: If we do gzip encoding transforms mid-transfer, this will fail.
    self.download_thread = threading.Thread(
        target=PerformDownload,
        args=(start_byte, progress_callback))
    self.download_thread.start()

  def read(self, amt=None):  # pylint: disable=invalid-name
    """Exposes a stream from the in-memory buffer to the upload.

    Waits until the download thread has buffered at least one chunk, then
    returns the oldest buffered chunk whole.

    Args:
      amt: Maximum number of bytes the caller accepts. Must be given and be
          at most TRANSFER_BUFFER_SIZE, unless the stream is at its end.

    Returns:
      The next buffered chunk, or an empty byte string if the read position
      is at the end of the object or amt is 0.

    Raises:
      BadRequestException: if amt is None or larger than TRANSFER_BUFFER_SIZE,
          or if the next buffered chunk is larger than amt.
      Exception: the exception that killed the download thread, if the buffer
          is empty and the download thread recorded one.
    """
    if self.position == self.src_obj_size or amt == 0:
      # If there is no data left or 0 bytes were requested, return an empty
      # byte string so callers can still call len() and read(0).
      return b''
    if amt is None or amt > TRANSFER_BUFFER_SIZE:
      raise BadRequestException(
          'Invalid HTTP read size %s during daisy chain operation, '
          'expected <= %s.' % (amt, TRANSFER_BUFFER_SIZE))

    while True:
      with self.lock:
        if self.buffer:
          break
        with self.download_exception_lock:
          if self.download_exception:
            # Download thread died, so we will never recover. Raise the
            # exception that killed it.
            raise self.download_exception  # pylint: disable=raising-bad-type
      # Buffer was empty, yield thread priority so the download thread can fill.
      time.sleep(0)
    with self.lock:
      # TODO: Need to handle the caller requesting less than a
      # transfer_buffer_size worth of data.
      data = self.buffer.popleft()
      # Remember this chunk so a seek back to its start can replay it.
      self.last_position = self.position
      self.last_data = data
      data_len = len(data)
      self.position += data_len
      self.bytes_buffered -= data_len
    if data_len > amt:
      raise BadRequestException(
          'Invalid read during daisy chain operation, got data of size '
          '%s, expected size %s.' % (data_len, amt))
    return data

  def tell(self):  # pylint: disable=invalid-name
    """Returns the current read position of the upload file pointer."""
    with self.lock:
      return self.position

  def seek(self, offset, whence=os.SEEK_SET):  # pylint: disable=invalid-name
    """Moves the read position, restarting the download when necessary.

    Seeking to the current position is a no-op, and seeking to the start of
    the most recently read chunk replays that chunk from memory. Any other
    absolute seek stops the download thread, discards buffered data and starts
    a new download at offset.

    Args:
      offset: Target byte offset. Must be 0 when whence is os.SEEK_END.
      whence: os.SEEK_SET (default) or os.SEEK_END.

    Raises:
      IOError: if whence is os.SEEK_END with a non-zero offset, or whence is
          any other unsupported mode.
    """
    restart_download = False
    if whence == os.SEEK_END:
      if offset:
        raise IOError(
            'Invalid seek during daisy chain operation. Non-zero offset %s '
            'from os.SEEK_END is not supported' % offset)
      with self.lock:
        self.last_position = self.position
        self.last_data = None
        # Safe because we check position against src_obj_size in read.
        self.position = self.src_obj_size
    elif whence == os.SEEK_SET:
      with self.lock:
        if offset == self.position:
          pass
        elif offset == self.last_position:
          self.position = self.last_position
          if self.last_data:
            # If we seek to end and then back, we won't have last_data; we'll
            # get it on the next call to read.
            self.buffer.appendleft(self.last_data)
            self.bytes_buffered += len(self.last_data)
        else:
          # Once a download is complete, boto seeks to 0 and re-reads to
          # compute the hash if an md5 isn't already present (for example a GCS
          # composite object), so we have to re-download the whole object.
          # Also, when daisy-chaining to a resumable upload, on error the
          # service may have received any number of the bytes; the download
          # needs to be restarted from that point.
          restart_download = True

      if restart_download:
        self.stop_download.set()

        # Consume any remaining bytes in the download thread so that
        # the thread can exit, then restart the thread at the desired position.
        while self.download_thread.is_alive():
          with self.lock:
            while self.bytes_buffered:
              self.bytes_buffered -= len(self.buffer.popleft())
          time.sleep(0)

        # The old thread clears the event only when it notices it between
        # chunks. If it exited some other way (it finished the final chunk or
        # died on an exception) the event is still set, and the replacement
        # thread would stop after its first chunk, leaving read() waiting
        # forever. The old thread is dead here, so resetting is race-free.
        self.stop_download.clear()

        with self.lock:
          self.position = offset
          self.buffer = deque()
          self.bytes_buffered = 0
          self.last_position = 0
          self.last_data = None
        self.StartDownloadThread(start_byte=offset,
                                 progress_callback=self.progress_callback)
    else:
      raise IOError('Daisy-chain download wrapper does not support '
                    'seek mode %s' % whence)

  def seekable(self):  # pylint: disable=invalid-name
    """Returns True; seek() supports the limited set of seeks described there."""
    return True
283