#!/usr/bin/env python
# Copyright 2010 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Output writers for MapReduce."""

from __future__ import with_statement



__all__ = [
    "GoogleCloudStorageConsistentOutputWriter",
    "GoogleCloudStorageConsistentRecordOutputWriter",
    "GoogleCloudStorageKeyValueOutputWriter",
    "GoogleCloudStorageOutputWriter",
    "GoogleCloudStorageRecordOutputWriter",
    "COUNTER_IO_WRITE_BYTES",
    "COUNTER_IO_WRITE_MSEC",
    "OutputWriter",
    "GCSRecordsPool"
    ]

# pylint: disable=g-bad-name
# pylint: disable=protected-access

import cStringIO
import gc
import logging
import pickle
import random
import string
import time

from mapreduce import context
from mapreduce import errors
from mapreduce import json_util
from mapreduce import kv_pb
from mapreduce import model
from mapreduce import operation
from mapreduce import records
from mapreduce import shard_life_cycle

# pylint: disable=g-import-not-at-top
# TODO(user): Cleanup imports if/when cloudstorage becomes part of runtime.
try:
  # Check if the full cloudstorage package exists. The stub part is in runtime.
  cloudstorage = None
  import cloudstorage
  if hasattr(cloudstorage, "_STUB"):
    cloudstorage = None
  # "if" is needed because apphosting/ext/datastore_admin:main_test fails.
  if cloudstorage:
    from cloudstorage import cloudstorage_api
    from cloudstorage import errors as cloud_errors
except ImportError:
  pass  # CloudStorage library not available
# Attempt to load cloudstorage from the bundle (available in some tests).
if cloudstorage is None:
  try:
    import cloudstorage
    from cloudstorage import cloudstorage_api
  except ImportError:
    pass  # CloudStorage library really not available


# Counter name for number of bytes written.
COUNTER_IO_WRITE_BYTES = "io-write-bytes"

# Counter name for time spent writing data in msec.
COUNTER_IO_WRITE_MSEC = "io-write-msec"


class OutputWriter(json_util.JsonMixin):
  """Abstract base class for output writers.

  Output writers process all mapper handler output that is not
  an operation.

  OutputWriter's lifecycle is the following:
    0) validate() is called to validate the mapper specification.
    1) init_job() is called to initialize any job-level state.
    2) create() is called, which should create a new instance of the output
       writer for a given shard.
    3) from_json()/to_json() are used to persist the writer's state across
       multiple slices.
    4) write() is called to write data.
    5) finalize() is called when shard processing is done.
    6) finalize_job() is called when the job is completed.
    7) get_filenames() is called to get output file names.
  """

  @classmethod
  def validate(cls, mapper_spec):
    """Validates mapper specification.

    Output writer parameters are expected to be passed in the "output_writer"
    subdictionary of mapper_spec.params. To stay compatible with the previous
    API, an output writer is advised to check mapper_spec.params and issue
    a warning if the "output_writer" subdictionary is not present.
    The _get_params helper method can be used to simplify the implementation.

    Args:
      mapper_spec: an instance of model.MapperSpec to validate.
    """
    raise NotImplementedError("validate() not implemented in %s" % cls)

  @classmethod
  def init_job(cls, mapreduce_state):
    """Initialize job-level writer state.

    This method exists only to support the deprecated feature of output
    files shared by many shards. New output writers should not do anything
    in this method.

    Args:
      mapreduce_state: an instance of model.MapreduceState describing current
      job. MapreduceState.writer_state can be modified during initialization
      to save the information about the files shared by many shards.
    """
    pass

  @classmethod
  def finalize_job(cls, mapreduce_state):
    """Finalize job-level writer state.

    This method exists only to support the deprecated feature of output
    files shared by many shards. New output writers should not do anything
    in this method.

    This method should only be called when mapreduce_state.result_status shows
    success. After finalizing the outputs, it should save the info for shard
    shared files into mapreduce_state.writer_state so that other operations
    can find the outputs.

    Args:
      mapreduce_state: an instance of model.MapreduceState describing current
      job. MapreduceState.writer_state can be modified during finalization.
    """
    pass

  @classmethod
  def from_json(cls, state):
    """Creates an instance of the OutputWriter for the given json state.

    Args:
      state: The OutputWriter state as a dict-like object.

    Returns:
      An instance of the OutputWriter configured from the given json state.
    """
    raise NotImplementedError("from_json() not implemented in %s" % cls)

  def to_json(self):
    """Returns writer state to serialize in json.

    Returns:
      A json-izable version of the OutputWriter state.
    """
    raise NotImplementedError("to_json() not implemented in %s" %
                              self.__class__)

  @classmethod
  def create(cls, mr_spec, shard_number, shard_attempt, _writer_state=None):
    """Create new writer for a shard.

    Args:
      mr_spec: an instance of model.MapreduceSpec describing current job.
      shard_number: int shard number.
      shard_attempt: int shard attempt.
      _writer_state: deprecated. This is for old writers that share file
        across shards. For new writers, each shard must have its own
        dedicated outputs. Output state should be contained in
        the output writer instance. The serialized output writer
        instance will be saved by mapreduce across slices.
    """
    raise NotImplementedError("create() not implemented in %s" % cls)

  def write(self, data):
    """Write data.

    Args:
      data: actual data yielded from handler. Type is writer-specific.
    """
    raise NotImplementedError("write() not implemented in %s" %
                              self.__class__)

  def finalize(self, ctx, shard_state):
    """Finalize writer shard-level state.

    This should only be called when shard_state.result_status shows success.
    After finalizing the outputs, it should save per-shard output file info
    into shard_state.writer_state so that other operations can find the
    outputs.

    Args:
      ctx: an instance of context.Context.
      shard_state: shard state. ShardState.writer_state can be modified.
    """
    raise NotImplementedError("finalize() not implemented in %s" %
                              self.__class__)

  @classmethod
  def get_filenames(cls, mapreduce_state):
    """Obtain output filenames from mapreduce state.

    This method should only be called when a MR is finished. Implementors of
    this method should not assume any other methods of this class have been
    called. In the case of no input data, no other method except validate
    would have been called.

    Args:
      mapreduce_state: an instance of model.MapreduceState

    Returns:
      List of filenames this mapreduce successfully wrote to. The list can be
      empty if no output file was successfully written.
    """
    raise NotImplementedError("get_filenames() not implemented in %s" % cls)

  # pylint: disable=unused-argument
  def _supports_shard_retry(self, tstate):
    """Whether this output writer instance supports shard retry.

    Args:
      tstate: model.TransientShardState for current shard.

    Returns:
      boolean. Whether this output writer instance supports shard retry.
    """
    return False

  def _supports_slice_recovery(self, mapper_spec):
    """Whether this output writer supports slice recovery.

    Args:
      mapper_spec: instance of model.MapperSpec.

    Returns:
      boolean. Whether this output writer instance supports slice recovery.
    """
    return False

  # pylint: disable=unused-argument
  def _recover(self, mr_spec, shard_number, shard_attempt):
    """Create a new output writer instance from the old one.

    This method is called when _supports_slice_recovery returns True,
    and when there is a chance the old output writer instance is out of sync
    with its storage medium due to a retry of a slice. _recover should
    create a new instance based on the old one. When finalize is called
    on the new instance, it could combine valid outputs from all instances
    to generate the final output. How the new instance maintains references
    to previous outputs is up to implementation.

    Any exception during recovery is subject to normal slice/shard retry.
    So recovery logic must be idempotent.

    Args:
      mr_spec: an instance of model.MapreduceSpec describing current job.
      shard_number: int shard number.
      shard_attempt: int shard attempt.

    Returns:
      a new instance of output writer.
    """
    raise NotImplementedError()


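# A minimal sketch of a custom writer following the lifecycle described in
# OutputWriter's docstring. It is purely illustrative and not part of this
# module; "LoggingOutputWriter" and its behavior (counting records and
# discarding them) are assumptions made for the example.
#
#   class LoggingOutputWriter(OutputWriter):
#
#     def __init__(self, count=0):
#       self._count = count
#
#     @classmethod
#     def validate(cls, mapper_spec):
#       _get_params(mapper_spec)  # no required parameters
#
#     @classmethod
#     def create(cls, mr_spec, shard_number, shard_attempt,
#                _writer_state=None):
#       return cls()
#
#     @classmethod
#     def from_json(cls, state):
#       return cls(state.get("count", 0))
#
#     def to_json(self):
#       return {"count": self._count}
#
#     def write(self, data):
#       logging.info("mapper output: %r", data)
#       self._count += 1
#
#     def finalize(self, ctx, shard_state):
#       shard_state.writer_state = {"count": self._count}
#
#     @classmethod
#     def get_filenames(cls, mapreduce_state):
#       return []  # nothing is written to files

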
# Flush size for files api write requests. Approximately one block of data.
_FILE_POOL_FLUSH_SIZE = 128*1024

# Maximum size of files api request. Slightly less than 1M.
_FILE_POOL_MAX_SIZE = 1000*1024


def _get_params(mapper_spec, allowed_keys=None, allow_old=True):
  """Obtain output writer parameters.

  Utility function for output writer implementation. Fetches parameters
  from mapreduce specification giving appropriate usage warnings.

  Args:
    mapper_spec: The MapperSpec for the job
    allowed_keys: set of all allowed keys in parameters as strings. If it is not
      None, then parameters are expected to be in a separate "output_writer"
      subdictionary of mapper_spec parameters.
    allow_old: Allow parameters to exist outside of the output_writer
      subdictionary for compatibility.

  Returns:
    mapper parameters as dict

  Raises:
    BadWriterParamsError: if parameters are invalid/missing or not allowed.
  """
  if "output_writer" not in mapper_spec.params:
    message = (
        "Output writer's parameters should be specified in "
        "output_writer subdictionary.")
    if not allow_old or allowed_keys:
      raise errors.BadWriterParamsError(message)
    params = mapper_spec.params
    params = dict((str(n), v) for n, v in params.iteritems())
  else:
    if not isinstance(mapper_spec.params.get("output_writer"), dict):
      raise errors.BadWriterParamsError(
          "Output writer parameters should be a dictionary")
    params = mapper_spec.params.get("output_writer")
    params = dict((str(n), v) for n, v in params.iteritems())
    if allowed_keys:
      params_diff = set(params.keys()) - allowed_keys
      if params_diff:
        raise errors.BadWriterParamsError(
            "Invalid output_writer parameters: %s" % ",".join(params_diff))
  return params

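# For illustration, the two accepted parameter layouts look roughly like the
# sketch below ("my-bucket" is a made-up placeholder). The "output_writer"
# subdictionary form is the preferred one; the flat form is only accepted
# when allow_old=True and no allowed_keys are given.
#
#   preferred = {"output_writer": {"bucket_name": "my-bucket"}}
#   legacy = {"bucket_name": "my-bucket"}
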

class _RecordsPoolBase(context.Pool):
  """Base class for Pool of append operations for records files."""

  # Approximate number of bytes of overhead for storing one record.
  _RECORD_OVERHEAD_BYTES = 10

  def __init__(self,
               flush_size_chars=_FILE_POOL_FLUSH_SIZE,
               ctx=None,
               exclusive=False):
    """Constructor.

    Any classes that subclass this will need to implement the _write() function.

    Args:
      flush_size_chars: buffer flush threshold as int.
      ctx: mapreduce context as context.Context.
      exclusive: a boolean flag indicating if the pool has an exclusive
        access to the file. If it is True, then it's possible to write
        bigger chunks of data.
    """
    self._flush_size = flush_size_chars
    self._buffer = []
    self._size = 0
    self._ctx = ctx
    self._exclusive = exclusive

  def append(self, data):
    """Append data to a file."""
    data_length = len(data)
    if self._size + data_length > self._flush_size:
      self.flush()

    if not self._exclusive and data_length > _FILE_POOL_MAX_SIZE:
      raise errors.Error(
          "Input too large: %s bytes (limit is %s bytes)." %
          (data_length, _FILE_POOL_MAX_SIZE))
    else:
      self._buffer.append(data)
      self._size += data_length

    if self._size > self._flush_size:
      self.flush()

  def flush(self):
    """Flush pool contents."""
    # Write data to in-memory buffer first.
    buf = cStringIO.StringIO()
    with records.RecordsWriter(buf) as w:
      for record in self._buffer:
        w.write(record)
      w._pad_block()
    str_buf = buf.getvalue()
    buf.close()

    if not self._exclusive and len(str_buf) > _FILE_POOL_MAX_SIZE:
      # Shouldn't really happen because of flush size.
      raise errors.Error(
          "Buffer too big. Can't write more than %s bytes in one request: "
          "risk of writes interleaving. Got: %s" %
          (_FILE_POOL_MAX_SIZE, len(str_buf)))

    # Write data to file.
    start_time = time.time()
    self._write(str_buf)
    if self._ctx:
      operation.counters.Increment(
          COUNTER_IO_WRITE_BYTES, len(str_buf))(self._ctx)
      operation.counters.Increment(
          COUNTER_IO_WRITE_MSEC,
          int((time.time() - start_time) * 1000))(self._ctx)

    # reset buffer
    self._buffer = []
    self._size = 0
    gc.collect()

  def _write(self, str_buf):
    raise NotImplementedError("_write() not implemented in %s" % type(self))

  def __enter__(self):
    return self

  def __exit__(self, atype, value, traceback):
    self.flush()


class GCSRecordsPool(_RecordsPoolBase):
  """Pool of append operations for records using GCS."""

  # GCS writes in 256K blocks.
  _GCS_BLOCK_SIZE = 256 * 1024  # 256K

  def __init__(self,
               filehandle,
               flush_size_chars=_FILE_POOL_FLUSH_SIZE,
               ctx=None,
               exclusive=False):
    """Requires the filehandle of an open GCS file to write to."""
    super(GCSRecordsPool, self).__init__(flush_size_chars, ctx, exclusive)
    self._filehandle = filehandle
    self._buf_size = 0

  def _write(self, str_buf):
    """Uses the filehandle to the file in GCS to write to it."""
    self._filehandle.write(str_buf)
    self._buf_size += len(str_buf)

  def flush(self, force=False):
    """Flush pool contents.

    Args:
      force: boolean. If True, pad the data written so far up to a multiple
        of the GCS block size before flushing the file handle.
    """
    super(GCSRecordsPool, self).flush()
    if force:
      extra_padding = self._buf_size % self._GCS_BLOCK_SIZE
      if extra_padding > 0:
        self._write("\x00" * (self._GCS_BLOCK_SIZE - extra_padding))
    self._filehandle.flush()

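# A usage sketch for GCSRecordsPool (illustrative only; the bucket and object
# names are placeholders and the surrounding mapreduce machinery is omitted):
#
#   f = cloudstorage.open("/my-bucket/records-output", mode="w")
#   pool = GCSRecordsPool(f, ctx=context.get())
#   for record in ("rec1", "rec2", "rec3"):
#     pool.append(record)
#   pool.flush(force=True)  # pad to a GCS block boundary
#   f.close()
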

class _GoogleCloudStorageBase(shard_life_cycle._ShardLifeCycle,
                              OutputWriter):
  """Base abstract class for all GCS writers.

  Required configuration in the mapper_spec.output_writer dictionary:
    BUCKET_NAME_PARAM: name of the bucket to use, with no extra delimiters or
      suffixes such as directories. Directories/prefixes can be specified as
      part of the NAMING_FORMAT_PARAM.

  Optional configuration in the mapper_spec.output_writer dictionary:
    ACL_PARAM: acl to apply to new files, else bucket default used.
    NAMING_FORMAT_PARAM: prefix format string for the new files. No leading
      slash is required (any leading slash is treated as part of the file
      name); an expected format looks like "directory/basename...". The
      format string may use the following substitutions:
        $name - the name of the job
        $id - the id assigned to the job
        $num - the shard number
      If there is more than one shard, $num must be used. An arbitrary suffix
      may be applied by the writer.
    CONTENT_TYPE_PARAM: mime type to apply on the files. If not provided, Google
      Cloud Storage will apply its default.
    TMP_BUCKET_NAME_PARAM: name of the bucket used for writing tmp files by
      consistent GCS output writers. Defaults to BUCKET_NAME_PARAM if not set.
  """

  BUCKET_NAME_PARAM = "bucket_name"
  TMP_BUCKET_NAME_PARAM = "tmp_bucket_name"
  ACL_PARAM = "acl"
  NAMING_FORMAT_PARAM = "naming_format"
  CONTENT_TYPE_PARAM = "content_type"

  # Internal parameters.
  _ACCOUNT_ID_PARAM = "account_id"
  _TMP_ACCOUNT_ID_PARAM = "tmp_account_id"

  @classmethod
  def _get_gcs_bucket(cls, writer_spec):
    return writer_spec[cls.BUCKET_NAME_PARAM]

  @classmethod
  def _get_account_id(cls, writer_spec):
    return writer_spec.get(cls._ACCOUNT_ID_PARAM, None)

  @classmethod
  def _get_tmp_gcs_bucket(cls, writer_spec):
    """Returns bucket used for writing tmp files."""
    if cls.TMP_BUCKET_NAME_PARAM in writer_spec:
      return writer_spec[cls.TMP_BUCKET_NAME_PARAM]
    return cls._get_gcs_bucket(writer_spec)

  @classmethod
  def _get_tmp_account_id(cls, writer_spec):
    """Returns the account id to use with tmp bucket."""
    # pick tmp id iff tmp bucket is set explicitly
    if cls.TMP_BUCKET_NAME_PARAM in writer_spec:
      return writer_spec.get(cls._TMP_ACCOUNT_ID_PARAM, None)
    return cls._get_account_id(writer_spec)

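# An example configuration block as it would appear under
# mapper_spec.params["output_writer"]. All values below are placeholders
# chosen for illustration, not defaults of this library:
#
#   output_writer_params = {
#       "bucket_name": "my-output-bucket",
#       "tmp_bucket_name": "my-tmp-bucket",                 # optional
#       "acl": "project-private",                           # optional
#       "naming_format": "reports/$name/$id/output-$num",   # optional
#       "content_type": "text/plain",                       # optional
#   }
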

class _GoogleCloudStorageOutputWriterBase(_GoogleCloudStorageBase):
  """Base class for GCS writers directly interacting with GCS.

  Base class for both _GoogleCloudStorageOutputWriter and
  GoogleCloudStorageConsistentOutputWriter.

  This class is expected to be subclassed with a writer that applies formatting
  to user-level records.

  Subclasses need to define to_json, from_json, create, finalize and
  _get_write_buffer methods.

  See _GoogleCloudStorageBase for config options.
  """

  # Default settings
  _DEFAULT_NAMING_FORMAT = "$name/$id/output-$num"

  # Internal parameters
  _MR_TMP = "gae_mr_tmp"
  _TMP_FILE_NAMING_FORMAT = (
      _MR_TMP + "/$name/$id/attempt-$attempt/output-$num/seg-$seg")

  @classmethod
  def _generate_filename(cls, writer_spec, name, job_id, num,
                         attempt=None, seg_index=None):
    """Generates a filename for a particular output.

    Args:
      writer_spec: specification dictionary for the output writer.
      name: name of the job.
      job_id: the ID number assigned to the job.
      num: shard number.
      attempt: the shard attempt number.
      seg_index: index of the seg. None means the final output.

    Returns:
      a string containing the filename.

    Raises:
      BadWriterParamsError: if the template contains any errors such as invalid
        syntax or contains unknown substitution placeholders.
    """
    naming_format = cls._TMP_FILE_NAMING_FORMAT
    if seg_index is None:
      naming_format = writer_spec.get(cls.NAMING_FORMAT_PARAM,
                                      cls._DEFAULT_NAMING_FORMAT)

    template = string.Template(naming_format)
    try:
      # Check that template doesn't use undefined mappings and is formatted well
      if seg_index is None:
        return template.substitute(name=name, id=job_id, num=num)
      else:
        return template.substitute(name=name, id=job_id, num=num,
                                   attempt=attempt,
                                   seg=seg_index)
    except ValueError, error:
      raise errors.BadWriterParamsError("Naming template is bad, %s" % (error))
    except KeyError, error:
      raise errors.BadWriterParamsError("Naming template '%s' has extra "
                                        "mappings, %s" % (naming_format, error))

  @classmethod
  def get_params(cls, mapper_spec, allowed_keys=None, allow_old=True):
    params = _get_params(mapper_spec, allowed_keys, allow_old)
    # Use the bucket_name defined in mapper_spec params if one was not defined
    # specifically in the output_writer params.
    if (mapper_spec.params.get(cls.BUCKET_NAME_PARAM) is not None and
        params.get(cls.BUCKET_NAME_PARAM) is None):
      params[cls.BUCKET_NAME_PARAM] = mapper_spec.params[cls.BUCKET_NAME_PARAM]
    return params

  @classmethod
  def validate(cls, mapper_spec):
    """Validate mapper specification.

    Args:
      mapper_spec: an instance of model.MapperSpec.

    Raises:
      BadWriterParamsError: if the specification is invalid for any reason such
        as missing the bucket name or providing an invalid bucket name.
    """
    writer_spec = cls.get_params(mapper_spec, allow_old=False)

    # Bucket Name is required
    if cls.BUCKET_NAME_PARAM not in writer_spec:
      raise errors.BadWriterParamsError(
          "%s is required for Google Cloud Storage" %
          cls.BUCKET_NAME_PARAM)
    try:
      cloudstorage.validate_bucket_name(
          writer_spec[cls.BUCKET_NAME_PARAM])
    except ValueError, error:
      raise errors.BadWriterParamsError("Bad bucket name, %s" % (error))

    # Validate the naming format does not throw any errors using dummy values
    cls._generate_filename(writer_spec, "name", "id", 0)
    cls._generate_filename(writer_spec, "name", "id", 0, 1, 0)

  @classmethod
  def _open_file(cls, writer_spec, filename_suffix, use_tmp_bucket=False):
    """Opens a new gcs file for writing."""
    if use_tmp_bucket:
      bucket = cls._get_tmp_gcs_bucket(writer_spec)
      account_id = cls._get_tmp_account_id(writer_spec)
    else:
      bucket = cls._get_gcs_bucket(writer_spec)
      account_id = cls._get_account_id(writer_spec)

    # GoogleCloudStorage filename format; an initial slash is required.
    filename = "/%s/%s" % (bucket, filename_suffix)

    content_type = writer_spec.get(cls.CONTENT_TYPE_PARAM, None)

    options = {}
    if cls.ACL_PARAM in writer_spec:
      options["x-goog-acl"] = writer_spec.get(cls.ACL_PARAM)

    return cloudstorage.open(filename, mode="w", content_type=content_type,
                             options=options, _account_id=account_id)

  @classmethod
  def _get_filename(cls, shard_state):
    return shard_state.writer_state["filename"]

  @classmethod
  def get_filenames(cls, mapreduce_state):
    filenames = []
    for shard in model.ShardState.find_all_by_mapreduce_state(mapreduce_state):
      if shard.result_status == model.ShardState.RESULT_SUCCESS:
        filenames.append(cls._get_filename(shard))
    return filenames

  def _get_write_buffer(self):
    """Returns a buffer to be used by the write() method."""
    raise NotImplementedError()

  def write(self, data):
    """Write data to the GoogleCloudStorage file.

    Args:
      data: string containing the data to be written.
    """
    start_time = time.time()
    self._get_write_buffer().write(data)
    ctx = context.get()
    operation.counters.Increment(COUNTER_IO_WRITE_BYTES, len(data))(ctx)
    operation.counters.Increment(
        COUNTER_IO_WRITE_MSEC, int((time.time() - start_time) * 1000))(ctx)

  # pylint: disable=unused-argument
  def _supports_shard_retry(self, tstate):
    return True


class _GoogleCloudStorageOutputWriter(_GoogleCloudStorageOutputWriterBase):
  """Naive version of GoogleCloudStorageWriter.

  This version is known to create inconsistent outputs if the input changes
  during slice retries. Consider using GoogleCloudStorageConsistentOutputWriter
  instead.

  Optional configuration in the mapper_spec.output_writer dictionary:
    _NO_DUPLICATE: if True, slice recovery logic will be used to ensure
      output files have no duplicates. Every shard should have only one final
      output in the user-specified location, but it may produce many smaller
      files (named "seg") due to slice recovery. These segs live in a
      tmp directory and should be combined and renamed to the final location.
      In the current implementation, they are not combined.
  """
  _SEG_PREFIX = "seg_prefix"
  _LAST_SEG_INDEX = "last_seg_index"
  _JSON_GCS_BUFFER = "buffer"
  _JSON_SEG_INDEX = "seg_index"
  _JSON_NO_DUP = "no_dup"
  # This can be used to store valid length with a GCS file.
  _VALID_LENGTH = "x-goog-meta-gae-mr-valid-length"
  _NO_DUPLICATE = "no_duplicate"

  # writer_spec only used by subclasses, pylint: disable=unused-argument
  def __init__(self, streaming_buffer, writer_spec=None):
    """Initialize a GoogleCloudStorageOutputWriter instance.

    Args:
      streaming_buffer: an instance of writable buffer from cloudstorage_api.
      writer_spec: the specification for the writer.
    """
    self._streaming_buffer = streaming_buffer
    self._no_dup = False
    if writer_spec:
      self._no_dup = writer_spec.get(self._NO_DUPLICATE, False)
    if self._no_dup:
      # This is the index of the current seg, starting at 0.
      # This number is incremented sequentially and every index
      # represents a real seg.
      self._seg_index = int(streaming_buffer.name.rsplit("-", 1)[1])
      # The valid length of the current seg as of the end of the previous
      # slice. This value is updated at the end of a slice, by which time
      # all content before it has already been either flushed to GCS or
      # serialized to the task payload.
      self._seg_valid_length = 0

  @classmethod
  def validate(cls, mapper_spec):
    """Inherit docs."""
    writer_spec = cls.get_params(mapper_spec, allow_old=False)
    if writer_spec.get(cls._NO_DUPLICATE, False) not in (True, False):
      raise errors.BadWriterParamsError("no_duplicate must be a boolean.")
    super(_GoogleCloudStorageOutputWriter, cls).validate(mapper_spec)

  def _get_write_buffer(self):
    return self._streaming_buffer

  @classmethod
  def create(cls, mr_spec, shard_number, shard_attempt, _writer_state=None):
    """Inherit docs."""
    writer_spec = cls.get_params(mr_spec.mapper, allow_old=False)
    seg_index = None
    if writer_spec.get(cls._NO_DUPLICATE, False):
      seg_index = 0

    # Determine parameters
    key = cls._generate_filename(writer_spec, mr_spec.name,
                                 mr_spec.mapreduce_id,
                                 shard_number, shard_attempt,
                                 seg_index)
    return cls._create(writer_spec, key)

  @classmethod
  def _create(cls, writer_spec, filename_suffix):
    """Helper method that actually creates the file in cloud storage."""
    writer = cls._open_file(writer_spec, filename_suffix)
    return cls(writer, writer_spec=writer_spec)

  @classmethod
  def from_json(cls, state):
    writer = cls(pickle.loads(state[cls._JSON_GCS_BUFFER]))
    no_dup = state.get(cls._JSON_NO_DUP, False)
    writer._no_dup = no_dup
    if no_dup:
      writer._seg_valid_length = state[cls._VALID_LENGTH]
      writer._seg_index = state[cls._JSON_SEG_INDEX]
    return writer

  def end_slice(self, slice_ctx):
    if not self._streaming_buffer.closed:
      self._streaming_buffer.flush()

  def to_json(self):
    result = {self._JSON_GCS_BUFFER: pickle.dumps(self._streaming_buffer),
              self._JSON_NO_DUP: self._no_dup}
    if self._no_dup:
      result.update({
          # Save the length of what has been written, including what is
          # buffered in memory.
          # This assumes from_json and to_json are only called
          # at the beginning of a slice.
          # TODO(user): This may not be a good assumption.
          self._VALID_LENGTH: self._streaming_buffer.tell(),
          self._JSON_SEG_INDEX: self._seg_index})
    return result

  def finalize(self, ctx, shard_state):
    self._streaming_buffer.close()

    if self._no_dup:
      cloudstorage_api.copy2(
          self._streaming_buffer.name,
          self._streaming_buffer.name,
          metadata={self._VALID_LENGTH: self._streaming_buffer.tell()})

      # The filename the user requested.
      mr_spec = ctx.mapreduce_spec
      writer_spec = self.get_params(mr_spec.mapper, allow_old=False)
      filename = self._generate_filename(writer_spec,
                                         mr_spec.name,
                                         mr_spec.mapreduce_id,
                                         shard_state.shard_number)
      seg_filename = self._streaming_buffer.name
      prefix, last_index = seg_filename.rsplit("-", 1)
      # This info is enough for any external process to combine
      # all segs into the final file.
      # TODO(user): Create a special input reader to combine segs.
      shard_state.writer_state = {self._SEG_PREFIX: prefix + "-",
                                  self._LAST_SEG_INDEX: int(last_index),
                                  "filename": filename}
    else:
      shard_state.writer_state = {"filename": self._streaming_buffer.name}

  def _supports_slice_recovery(self, mapper_spec):
    writer_spec = self.get_params(mapper_spec, allow_old=False)
    return writer_spec.get(self._NO_DUPLICATE, False)

  def _recover(self, mr_spec, shard_number, shard_attempt):
    next_seg_index = self._seg_index

    # Save the current seg if it actually has something.
    # Remember self._streaming_buffer is the pickled instance
    # from the previous slice.
    if self._seg_valid_length != 0:
      try:
        gcs_next_offset = self._streaming_buffer._get_offset_from_gcs() + 1
        # If GCS is ahead of us, just force close.
        if gcs_next_offset > self._streaming_buffer.tell():
          self._streaming_buffer._force_close(gcs_next_offset)
        # Otherwise flush in memory contents too.
        else:
          self._streaming_buffer.close()
      except cloudstorage.FileClosedError:
        pass
      cloudstorage_api.copy2(
          self._streaming_buffer.name,
          self._streaming_buffer.name,
          metadata={self._VALID_LENGTH:
                    self._seg_valid_length})
      next_seg_index = self._seg_index + 1

    writer_spec = self.get_params(mr_spec.mapper, allow_old=False)
    # Create name for the new seg.
    key = self._generate_filename(
        writer_spec, mr_spec.name,
        mr_spec.mapreduce_id,
        shard_number,
        shard_attempt,
        next_seg_index)
    new_writer = self._create(writer_spec, key)
    new_writer._seg_index = next_seg_index
    return new_writer

  def _get_filename_for_test(self):
    return self._streaming_buffer.name


GoogleCloudStorageOutputWriter = _GoogleCloudStorageOutputWriter


class _ConsistentStatus(object):
  """Object used to pass status to the next slice."""

  def __init__(self):
    self.writer_spec = None
    self.mapreduce_id = None
    self.shard = None
    self.mainfile = None
    self.tmpfile = None
    self.tmpfile_1ago = None


class GoogleCloudStorageConsistentOutputWriter(
    _GoogleCloudStorageOutputWriterBase):
  """Output writer to Google Cloud Storage using the cloudstorage library.

  This version ensures that the output written to GCS is consistent.
  """

  # Implementation details:
  # Each slice writes to a new tmpfile in GCS. When the slice is finished
  # (to_json is called) the file is finalized. When slice N is started
  # (from_json is called) it does the following:
  # - append the contents of N-1's tmpfile to the mainfile
  # - remove N-2's tmpfile
  #
  # When a slice fails the file is never finalized and will be garbage
  # collected. It is possible for the slice to fail just after the file is
  # finalized. We will leave a file behind in this case (we don't clean it up).
  #
  # Slice retries don't cause inconsistent and/or duplicate entries to be written
  # to the mainfile (rewriting tmpfile is an idempotent operation).

  _JSON_STATUS = "status"
  _RAND_BITS = 128
  _REWRITE_BLOCK_SIZE = 1024 * 256
  _REWRITE_MR_TMP = "gae_mr_tmp"
  _TMPFILE_PATTERN = _REWRITE_MR_TMP + "/$id-tmp-$shard-$random"
  _TMPFILE_PREFIX = _REWRITE_MR_TMP + "/$id-tmp-$shard-"

  def __init__(self, status):
    """Initialize a GoogleCloudStorageConsistentOutputWriter instance.

    Args:
      status: an instance of _ConsistentStatus with initialized tmpfile
              and mainfile.
    """

    self.status = status
    self._data_written_to_slice = False

  def _get_write_buffer(self):
    if not self.status.tmpfile:
      raise errors.FailJobError(
          "write buffer requested but none is open; begin_slice missing?")
    return self.status.tmpfile

  def _get_filename_for_test(self):
    return self.status.mainfile.name

  @classmethod
  def create(cls, mr_spec, shard_number, shard_attempt, _writer_state=None):
    """Inherit docs."""
    writer_spec = cls.get_params(mr_spec.mapper, allow_old=False)

    # Determine parameters
    key = cls._generate_filename(writer_spec, mr_spec.name,
                                 mr_spec.mapreduce_id,
                                 shard_number, shard_attempt)

    status = _ConsistentStatus()
    status.writer_spec = writer_spec
    status.mainfile = cls._open_file(writer_spec, key)
    status.mapreduce_id = mr_spec.mapreduce_id
    status.shard = shard_number

    return cls(status)

  def _remove_tmpfile(self, filename, writer_spec):
    if not filename:
      return
    account_id = self._get_tmp_account_id(writer_spec)
    try:
      cloudstorage_api.delete(filename, _account_id=account_id)
    except cloud_errors.NotFoundError:
      pass

  def _rewrite_tmpfile(self, mainfile, tmpfile, writer_spec):
    """Copies contents of tmpfile (name) to mainfile (buffer)."""
    if mainfile.closed:
      # can happen when finalize fails
      return

    account_id = self._get_tmp_account_id(writer_spec)
    f = cloudstorage_api.open(tmpfile, _account_id=account_id)
    # both reads and writes are buffered - the number here doesn't matter
    data = f.read(self._REWRITE_BLOCK_SIZE)
    while data:
      mainfile.write(data)
      data = f.read(self._REWRITE_BLOCK_SIZE)
    f.close()
    mainfile.flush()

  @classmethod
  def _create_tmpfile(cls, status):
    """Creates a new random-named tmpfile."""

    # We can't put the tmpfile in the same directory as the output. There are
    # rare circumstances when we leave trash behind and we don't want this trash
    # to be loaded into bigquery and/or used for restore.
    #
    # We use the mapreduce id, shard number, shard attempt, and 128 random
    # bits to make collisions virtually impossible.
    tmpl = string.Template(cls._TMPFILE_PATTERN)
    filename = tmpl.substitute(
        id=status.mapreduce_id, shard=status.shard,
        random=random.getrandbits(cls._RAND_BITS))

    return cls._open_file(status.writer_spec, filename, use_tmp_bucket=True)

  def begin_slice(self, slice_ctx):
    status = self.status
    writer_spec = status.writer_spec

    # we're slice N so we can safely remove N-2's tmpfile
    if status.tmpfile_1ago:
      self._remove_tmpfile(status.tmpfile_1ago.name, writer_spec)

    # rewrite N-1's tmpfile (idempotent)
    # The N-1 file might be needed if this slice is ever retried, so we need
    # to make sure it won't be cleaned up just yet.
    files_to_keep = []
    if status.tmpfile:  # does not exist on slice 0
      self._rewrite_tmpfile(status.mainfile, status.tmpfile.name, writer_spec)
      files_to_keep.append(status.tmpfile.name)

    # clean all the garbage you can find
    self._try_to_clean_garbage(
        writer_spec, exclude_list=files_to_keep)

    # Rotate the files in status.
    status.tmpfile_1ago = status.tmpfile
    status.tmpfile = self._create_tmpfile(status)

    # There's a test for this condition. Not sure if this can happen.
    if status.mainfile.closed:
      status.tmpfile.close()
      self._remove_tmpfile(status.tmpfile.name, writer_spec)

  @classmethod
  def from_json(cls, state):
    return cls(pickle.loads(state[cls._JSON_STATUS]))

  def end_slice(self, slice_ctx):
    self.status.tmpfile.close()

  def to_json(self):
    return {self._JSON_STATUS: pickle.dumps(self.status)}

  def write(self, data):
    super(GoogleCloudStorageConsistentOutputWriter, self).write(data)
    self._data_written_to_slice = True

  def _try_to_clean_garbage(self, writer_spec, exclude_list=()):
    """Tries to remove any files created by this shard that aren't needed.

    Args:
      writer_spec: writer_spec for the MR.
      exclude_list: A list of filenames (strings) that should not be
        removed.
    """
    # Try to remove garbage (if any). Note that listbucket is not strongly
    # consistent so something might survive.
    tmpl = string.Template(self._TMPFILE_PREFIX)
    prefix = tmpl.substitute(
        id=self.status.mapreduce_id, shard=self.status.shard)
    bucket = self._get_tmp_gcs_bucket(writer_spec)
    account_id = self._get_tmp_account_id(writer_spec)
    for f in cloudstorage.listbucket("/%s/%s" % (bucket, prefix),
                                     _account_id=account_id):
      if f.filename not in exclude_list:
        self._remove_tmpfile(f.filename, self.status.writer_spec)

  def finalize(self, ctx, shard_state):
    if self._data_written_to_slice:
      raise errors.FailJobError(
          "finalize() called after data was written")

    if self.status.tmpfile:
      self.status.tmpfile.close()  # it's empty
    self.status.mainfile.close()

    # rewrite happened, close happened, we can remove the tmp files
    if self.status.tmpfile_1ago:
      self._remove_tmpfile(self.status.tmpfile_1ago.name,
                           self.status.writer_spec)
    if self.status.tmpfile:
      self._remove_tmpfile(self.status.tmpfile.name,
                           self.status.writer_spec)

    self._try_to_clean_garbage(self.status.writer_spec)

    shard_state.writer_state = {"filename": self.status.mainfile.name}

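# A sketch of wiring this writer into a mapper pipeline. The handler and
# model names below are placeholders; the MapperPipeline argument order
# follows common usage of this library but should be checked against
# mapreduce.mapper_pipeline for your version:
#
#   from mapreduce import mapper_pipeline
#
#   pipeline = mapper_pipeline.MapperPipeline(
#       "my_job",
#       handler_spec="my_module.my_map_function",
#       input_reader_spec="mapreduce.input_readers.DatastoreInputReader",
#       output_writer_spec=(
#           "mapreduce.output_writers."
#           "GoogleCloudStorageConsistentOutputWriter"),
#       params={
#           "input_reader": {"entity_kind": "my_module.MyModel"},
#           "output_writer": {"bucket_name": "my-output-bucket"},
#       },
#       shards=4)
#   pipeline.start()
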

class _GoogleCloudStorageRecordOutputWriterBase(_GoogleCloudStorageBase):
  """Wraps a GCS writer with a records.RecordsWriter.

  This class wraps a WRITER_CLS (and its instance) and delegates most calls
  to it. write() calls are done using records.RecordsWriter.

  WRITER_CLS has to be set to a subclass of _GoogleCloudStorageOutputWriterBase.

  For list of supported parameters see _GoogleCloudStorageBase.
  """

  WRITER_CLS = None

  def __init__(self, writer):
    self._writer = writer
    self._record_writer = records.RecordsWriter(writer)

  @classmethod
  def validate(cls, mapper_spec):
    return cls.WRITER_CLS.validate(mapper_spec)

  @classmethod
  def init_job(cls, mapreduce_state):
    return cls.WRITER_CLS.init_job(mapreduce_state)

  @classmethod
  def finalize_job(cls, mapreduce_state):
    return cls.WRITER_CLS.finalize_job(mapreduce_state)

  @classmethod
  def from_json(cls, state):
    return cls(cls.WRITER_CLS.from_json(state))

  def to_json(self):
    return self._writer.to_json()

  @classmethod
  def create(cls, mr_spec, shard_number, shard_attempt, _writer_state=None):
    return cls(cls.WRITER_CLS.create(mr_spec, shard_number, shard_attempt,
                                     _writer_state))

  def write(self, data):
    self._record_writer.write(data)

  def finalize(self, ctx, shard_state):
    return self._writer.finalize(ctx, shard_state)

  @classmethod
  def get_filenames(cls, mapreduce_state):
    return cls.WRITER_CLS.get_filenames(mapreduce_state)

  def _supports_shard_retry(self, tstate):
    return self._writer._supports_shard_retry(tstate)

  def _supports_slice_recovery(self, mapper_spec):
    return self._writer._supports_slice_recovery(mapper_spec)

  def _recover(self, mr_spec, shard_number, shard_attempt):
    return self._writer._recover(mr_spec, shard_number, shard_attempt)

  def begin_slice(self, slice_ctx):
    return self._writer.begin_slice(slice_ctx)

  def end_slice(self, slice_ctx):
    # Pad if this is not the end_slice call after finalization.
    if not self._writer._get_write_buffer().closed:
      self._record_writer._pad_block()
    return self._writer.end_slice(slice_ctx)


class _GoogleCloudStorageRecordOutputWriter(
    _GoogleCloudStorageRecordOutputWriterBase):
  WRITER_CLS = _GoogleCloudStorageOutputWriter


GoogleCloudStorageRecordOutputWriter = _GoogleCloudStorageRecordOutputWriter


class GoogleCloudStorageConsistentRecordOutputWriter(
    _GoogleCloudStorageRecordOutputWriterBase):
  WRITER_CLS = GoogleCloudStorageConsistentOutputWriter


# TODO(user): Write a test for this.
class _GoogleCloudStorageKeyValueOutputWriter(
    _GoogleCloudStorageRecordOutputWriter):
  """Write key/values to Google Cloud Storage files in LevelDB format."""

  def write(self, data):
    if len(data) != 2:
      logging.error("Got bad tuple of length %d (2-tuple expected): %s",
                    len(data), data)
      # Skip malformed records instead of failing below with an IndexError.
      return

    try:
      key = str(data[0])
      value = str(data[1])
    except TypeError:
      logging.error("Expecting a tuple, but got %s: %s",
                    data.__class__.__name__, data)
      # Skip records that cannot be coerced to a (key, value) string pair.
      return

    proto = kv_pb.KeyValue()
    proto.set_key(key)
    proto.set_value(value)
    GoogleCloudStorageRecordOutputWriter.write(self, proto.Encode())


GoogleCloudStorageKeyValueOutputWriter = _GoogleCloudStorageKeyValueOutputWriter

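# A sketch of reading back what _GoogleCloudStorageKeyValueOutputWriter
# produces (illustrative only; the filename is a placeholder). Each record in
# the LevelDB-format file is assumed to be an encoded kv_pb.KeyValue:
#
#   f = cloudstorage.open("/my-bucket/my_job/12345/output-0")
#   reader = records.RecordsReader(f)
#   try:
#     while True:
#       proto = kv_pb.KeyValue()
#       proto.ParseFromString(reader.read())  # raises EOFError at the end
#       print proto.key(), proto.value()
#   except EOFError:
#     pass
#   f.close()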