model.py revision 4a4f2fe02baf385f6c24fc98c6e17bf6ac5e0724
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

17"""Model classes which are used to communicate between parts of implementation.
18
19These model classes are describing mapreduce, its current state and
20communication messages. They are either stored in the datastore or
21serialized to/from json and passed around with other means.
22"""

# Disable "Invalid method name"
# pylint: disable=g-bad-name



__all__ = ["MapreduceState",
           "MapperSpec",
           "MapreduceControl",
           "MapreduceSpec",
           "ShardState",
           "CountersMap",
           "TransientShardState",
           "QuerySpec",
           "HugeTask"]

import cgi
import datetime
import urllib
import zlib
from graphy import bar_chart
from graphy.backends import google_chart_api

try:
  import json
except ImportError:
  import simplejson as json

from google.appengine.api import memcache
from google.appengine.api import taskqueue
from google.appengine.datastore import datastore_rpc
from google.appengine.ext import db
from mapreduce import context
from mapreduce import hooks
from mapreduce import json_util
from mapreduce import util


# pylint: disable=protected-access


# Special datastore kinds for MR.
_MAP_REDUCE_KINDS = ("_AE_MR_MapreduceControl",
                     "_AE_MR_MapreduceState",
                     "_AE_MR_ShardState",
                     "_AE_MR_TaskPayload")


class _HugeTaskPayload(db.Model):
  """Model object to store task payload."""

  payload = db.BlobProperty()

  @classmethod
  def kind(cls):
    """Returns entity kind."""
    return "_AE_MR_TaskPayload"


class HugeTask(object):
  """HugeTask is a taskqueue.Task-like class that can store big payloads.

  Payloads are stored either in the task payload itself or in the datastore.
  Task handlers should inherit from the base_handler.HugeTaskHandler class.
  """

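  # A minimal usage sketch (illustrative only; the url and parent entity
  # below are hypothetical). HugeTask picks the cheapest encoding that fits:
  # raw urlencoded payload, zlib-compressed payload, or a _HugeTaskPayload
  # datastore entity referenced by key:
  #
  #   task = HugeTask(url="/mapreduce/worker",
  #                   params={"shard_id": "1-0"},
  #                   parent=state_entity)
  #   task.add("default")
  #
  # On the handler side, HugeTask.decode_payload(request) recovers the
  # original params dict.
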
  PAYLOAD_PARAM = "__payload"
  PAYLOAD_KEY_PARAM = "__payload_key"

  # Leave some wiggle room for headers and other fields.
  MAX_TASK_PAYLOAD = taskqueue.MAX_PUSH_TASK_SIZE_BYTES - 1024
  MAX_DB_PAYLOAD = datastore_rpc.BaseConnection.MAX_RPC_BYTES

  PAYLOAD_VERSION_HEADER = "AE-MR-Payload-Version"
  # Update version when payload handling is changed
  # in a backward incompatible way.
  PAYLOAD_VERSION = "1"

  def __init__(self,
               url,
               params,
               name=None,
               eta=None,
               countdown=None,
               parent=None,
               headers=None):
    """Init.

    Args:
      url: task url in str.
      params: a dict from str to str.
      name: task name.
      eta: task eta.
      countdown: task countdown.
      parent: parent entity of huge task's payload.
      headers: a dict of headers for the task.

    Raises:
      ValueError: when payload is too big even for datastore, or parent is
        not specified when payload is stored in datastore.
    """
    self.url = url
    self.name = name
    self.eta = eta
    self.countdown = countdown
    self._headers = {
        "Content-Type": "application/octet-stream",
        self.PAYLOAD_VERSION_HEADER: self.PAYLOAD_VERSION
    }
    if headers:
      self._headers.update(headers)

    # TODO(user): Find a more space efficient way than urlencoding.
    payload_str = urllib.urlencode(params)
    compressed_payload = ""
    if len(payload_str) > self.MAX_TASK_PAYLOAD:
      compressed_payload = zlib.compress(payload_str)

    # Payload is small. Don't bother with anything.
    if not compressed_payload:
      self._payload = payload_str
    # Compressed payload is small. Don't bother with datastore.
    elif len(compressed_payload) < self.MAX_TASK_PAYLOAD:
      self._payload = self.PAYLOAD_PARAM + compressed_payload
    elif len(compressed_payload) > self.MAX_DB_PAYLOAD:
      raise ValueError(
          "Payload from %s too big to be stored in database: %s" %
          (self.name, len(compressed_payload)))
    # Store payload in the datastore.
    else:
      if not parent:
        raise ValueError("Huge tasks should specify parent entity.")

      payload_entity = _HugeTaskPayload(payload=compressed_payload,
                                        parent=parent)
      payload_key = payload_entity.put()
      self._payload = self.PAYLOAD_KEY_PARAM + str(payload_key)

  def add(self, queue_name, transactional=False):
    """Add task to the queue."""
    task = self.to_task()
    task.add(queue_name, transactional)

  def to_task(self):
    """Convert to a taskqueue task."""
    # Never pass params to taskqueue.Task. Use payload instead. Otherwise,
    # it's up to a particular taskqueue implementation to generate
    # payload from params. It could blow up payload size over limit.
    return taskqueue.Task(
        url=self.url,
        payload=self._payload,
        name=self.name,
        eta=self.eta,
        countdown=self.countdown,
        headers=self._headers)

  @classmethod
  def decode_payload(cls, request):
    """Decode task payload.

    HugeTask controls its own payload entirely, including urlencoding.
    It doesn't depend on any particular web framework.

    Args:
      request: a webapp Request instance.

    Returns:
      A dict of str to str. The same as the params argument to __init__.

    Raises:
      DeprecationWarning: When the task payload was constructed by an older,
        incompatible version of mapreduce.
    """
    # TODO(user): Pass mr_id into headers. Otherwise, when payload decoding
    # fails, we can't abort an mr.
    if request.headers.get(cls.PAYLOAD_VERSION_HEADER) != cls.PAYLOAD_VERSION:
      raise DeprecationWarning(
          "Task is generated by an older incompatible version of mapreduce. "
          "Please kill this job manually")
    return cls._decode_payload(request.body)

  @classmethod
  def _decode_payload(cls, body):
    compressed_payload_str = None
    if body.startswith(cls.PAYLOAD_KEY_PARAM):
      payload_key = body[len(cls.PAYLOAD_KEY_PARAM):]
      payload_entity = _HugeTaskPayload.get(payload_key)
      compressed_payload_str = payload_entity.payload
    elif body.startswith(cls.PAYLOAD_PARAM):
      compressed_payload_str = body[len(cls.PAYLOAD_PARAM):]

    if compressed_payload_str:
      payload_str = zlib.decompress(compressed_payload_str)
    else:
      payload_str = body

    result = {}
    for (name, value) in cgi.parse_qs(payload_str).items():
      if len(value) == 1:
        result[name] = value[0]
      else:
        result[name] = value
    return result


class CountersMap(json_util.JsonMixin):
  """Maintains map from counter name to counter value.

  The class is used to provide basic arithmetic on counter values (bulk
  add/remove), increment individual values and store/load data from json.
  """

  def __init__(self, initial_map=None):
    """Constructor.

    Args:
      initial_map: initial counter values map from counter name (string) to
        counter value (int).
    """
    if initial_map:
      self.counters = initial_map
    else:
      self.counters = {}

  def __repr__(self):
    """Compute string representation."""
    return "mapreduce.model.CountersMap(%r)" % self.counters

  def get(self, counter_name, default=0):
    """Get current counter value.

    Args:
      counter_name: counter name as string.
      default: default value if one doesn't exist.

    Returns:
      current counter value as int, or default if the counter was not set.
    """
    return self.counters.get(counter_name, default)

  def increment(self, counter_name, delta):
    """Increment counter value.

    Args:
      counter_name: counter name as string.
      delta: increment delta as int.

    Returns:
      new counter value.
    """
    current_value = self.counters.get(counter_name, 0)
    new_value = current_value + delta
    self.counters[counter_name] = new_value
    return new_value

  def add_map(self, counters_map):
    """Add all counters from the map.

    For each counter in the passed map, adds its value to the counter in this
    map.

    Args:
      counters_map: CountersMap instance to add.
    """
    for counter_name in counters_map.counters:
      self.increment(counter_name, counters_map.counters[counter_name])

  def sub_map(self, counters_map):
    """Subtracts all counters from the map.

    For each counter in the passed map, subtracts its value from the counter
    in this map.

    Args:
      counters_map: CountersMap instance to subtract.
    """
    for counter_name in counters_map.counters:
      self.increment(counter_name, -counters_map.counters[counter_name])

  def clear(self):
    """Clear all values."""
    self.counters = {}

  def to_json(self):
    """Serializes all the data in this map into json form.

    Returns:
      json-compatible data representation.
    """
    return {"counters": self.counters}

  @classmethod
  def from_json(cls, json):
    """Create new CountersMap from the json data structure, encoded by to_json.

    Args:
      json: json representation of CountersMap.

    Returns:
      an instance of CountersMap with all data deserialized from json.
    """
    counters_map = cls()
    counters_map.counters = json["counters"]
    return counters_map

  def to_dict(self):
    """Convert to dictionary.

    Returns:
      a dictionary with counter name as key and counter values as value.
    """
    return self.counters


class MapperSpec(json_util.JsonMixin):
  """Contains a specification for the mapper phase of the mapreduce.

  A MapperSpec instance can be changed only during the mapreduce start
  process, and it remains immutable for the rest of the mapreduce execution.
  MapperSpec is passed as a payload to all mapreduce tasks in JSON encoding
  as part of MapreduceSpec.

  Specifying mapper handlers:
    * '<module_name>.<class_name>' - the __call__ method of a class instance
      will be called.
    * '<module_name>.<function_name>' - the function will be called.
    * '<module_name>.<class_name>.<method_name>' - the class will be
      instantiated and the method called.
  """

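  # For example (hypothetical module "my_jobs"), all three handler forms
  # resolve through util.handler_for_name:
  #
  #   "my_jobs.process_entity"      # module-level function
  #   "my_jobs.Processor"           # class instance; __call__ is invoked
  #   "my_jobs.Processor.process"   # method on a fresh Processor instance
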
  def __init__(self,
               handler_spec,
               input_reader_spec,
               params,
               shard_count,
               output_writer_spec=None):
    """Creates a new MapperSpec.

    Args:
      handler_spec: handler specification as string (see class doc for
        details).
      input_reader_spec: The class name of the input reader to use.
      params: Dictionary of additional parameters for the mapper.
      shard_count: number of shards to process in parallel.
      output_writer_spec: The class name of the output writer to use.

    Properties:
      handler_spec: name of handler class/function to use.
      input_reader_spec: The class name of the input reader to use.
      params: Dictionary of additional parameters for the mapper.
      shard_count: number of shards to process in parallel.
      output_writer_spec: The class name of the output writer to use.
    """
    self.handler_spec = handler_spec
    self.input_reader_spec = input_reader_spec
    self.output_writer_spec = output_writer_spec
    self.shard_count = int(shard_count)
    self.params = params

  def get_handler(self):
    """Get mapper handler instance.

    This always creates a new instance of the handler. If the handler is a
    callable instance, MR only wants to create a new instance at the
    beginning of a shard or shard retry. The pickled callable instance
    should be accessed from TransientShardState.

    Returns:
      handler instance as callable.
    """
    return util.handler_for_name(self.handler_spec)

  handler = property(get_handler)

  def input_reader_class(self):
    """Get input reader class.

    Returns:
      input reader class object.
    """
    return util.for_name(self.input_reader_spec)

  def output_writer_class(self):
    """Get output writer class.

    Returns:
      output writer class object.
    """
    return self.output_writer_spec and util.for_name(self.output_writer_spec)

  def to_json(self):
    """Serializes this MapperSpec into a json-izable object."""
    result = {
        "mapper_handler_spec": self.handler_spec,
        "mapper_input_reader": self.input_reader_spec,
        "mapper_params": self.params,
        "mapper_shard_count": self.shard_count
    }
    if self.output_writer_spec:
      result["mapper_output_writer"] = self.output_writer_spec
    return result

  def __str__(self):
    return "MapperSpec(%s, %s, %s, %s)" % (
        self.handler_spec, self.input_reader_spec, self.params,
        self.shard_count)

  @classmethod
  def from_json(cls, json):
    """Creates MapperSpec from a dict-like object."""
    return cls(json["mapper_handler_spec"],
               json["mapper_input_reader"],
               json["mapper_params"],
               json["mapper_shard_count"],
               json.get("mapper_output_writer"))

  def __eq__(self, other):
    if not isinstance(other, self.__class__):
      return False
    return self.to_json() == other.to_json()


class MapreduceSpec(json_util.JsonMixin):
  """Contains a specification for the whole mapreduce.

  A MapreduceSpec instance can be changed only during the mapreduce start
  process, and it remains immutable for the rest of the mapreduce execution.
  MapreduceSpec is passed as a payload to all mapreduce tasks in json
  encoding.
  """

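  # A construction sketch (ids, names and reader params illustrative); the
  # mapper_spec argument is the json dict form produced by MapperSpec.to_json:
  #
  #   mapper_json = MapperSpec(
  #       "my_jobs.process_entity",
  #       "mapreduce.input_readers.DatastoreInputReader",
  #       {"entity_kind": "my_jobs.MyModel"},
  #       shard_count=8).to_json()
  #   spec = MapreduceSpec("my job", "12345", mapper_json)
  #   assert spec == MapreduceSpec.from_json(spec.to_json())
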
  # Url to call when mapreduce finishes its execution.
  PARAM_DONE_CALLBACK = "done_callback"
  # Queue to use to call done callback.
  PARAM_DONE_CALLBACK_QUEUE = "done_callback_queue"

  def __init__(self,
               name,
               mapreduce_id,
               mapper_spec,
               params=None,
               hooks_class_name=None):
    """Create new MapreduceSpec.

    Args:
      name: The name of this mapreduce job type.
      mapreduce_id: ID of the mapreduce.
      mapper_spec: a dict containing a MapperSpec in its json form, as
        produced by MapperSpec.to_json.
      params: dictionary of additional mapreduce parameters.
      hooks_class_name: The fully qualified name of the hooks class to use.

    Properties:
      name: The name of this mapreduce job type.
      mapreduce_id: unique id of this mapreduce as string.
      mapper: This MapreduceSpec's instance of MapperSpec.
      params: dictionary of additional mapreduce parameters.
      hooks_class_name: The fully qualified name of the hooks class to use.
    """
    self.name = name
    self.mapreduce_id = mapreduce_id
    self.mapper = MapperSpec.from_json(mapper_spec)
    # Avoid a mutable default argument; fall back to a fresh dict.
    self.params = params or {}
    self.hooks_class_name = hooks_class_name
    self.__hooks = None
    self.get_hooks()  # Fail fast on an invalid hook class.

  def get_hooks(self):
    """Returns a hooks.Hooks instance or None if no hooks class was set."""
    if self.__hooks is None and self.hooks_class_name is not None:
      hooks_class = util.for_name(self.hooks_class_name)
      if not isinstance(hooks_class, type):
        raise ValueError("hooks_class_name must refer to a class, got %s" %
                         type(hooks_class).__name__)
      if not issubclass(hooks_class, hooks.Hooks):
        raise ValueError(
            "hooks_class_name must refer to a hooks.Hooks subclass")
      self.__hooks = hooks_class(self)

    return self.__hooks

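  # For example (hypothetical module "my_jobs"), a job can be wired to a
  # hooks.Hooks subclass by name; the instance is created once per spec:
  #
  #   spec = MapreduceSpec("job", "id", mapper_json,
  #                        hooks_class_name="my_jobs.MyHooks")
  #   spec.get_hooks()  # -> my_jobs.MyHooks instance bound to spec
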
  def to_json(self):
    """Serializes all data in this mapreduce spec into json form.

    Returns:
      data in json format.
    """
    mapper_spec = self.mapper.to_json()
    return {
        "name": self.name,
        "mapreduce_id": self.mapreduce_id,
        "mapper_spec": mapper_spec,
        "params": self.params,
        "hooks_class_name": self.hooks_class_name,
    }

  @classmethod
  def from_json(cls, json):
    """Create new MapreduceSpec from the json, encoded by to_json.

    Args:
      json: json representation of MapreduceSpec.

    Returns:
      an instance of MapreduceSpec with all data deserialized from json.
    """
    mapreduce_spec = cls(json["name"],
                         json["mapreduce_id"],
                         json["mapper_spec"],
                         json.get("params"),
                         json.get("hooks_class_name"))
    return mapreduce_spec

  def __str__(self):
    return str(self.to_json())

  def __eq__(self, other):
    if not isinstance(other, self.__class__):
      return False
    return self.to_json() == other.to_json()

  @classmethod
  def _get_mapreduce_spec(cls, mr_id):
    """Get Mapreduce spec from mr id."""
    key = 'GAE-MR-spec: %s' % mr_id
    spec_json = memcache.get(key)
    if spec_json:
      return cls.from_json(spec_json)
    state = MapreduceState.get_by_job_id(mr_id)
    spec = state.mapreduce_spec
    spec_json = spec.to_json()
    memcache.set(key, spec_json)
    return spec


class MapreduceState(db.Model):
  """Holds accumulated state of mapreduce execution.

  MapreduceState is stored in datastore with a key name equal to the
  mapreduce ID. Only controller tasks can write to MapreduceState.

  Properties:
    mapreduce_spec: cached deserialized MapreduceSpec instance. Read-only.
    active: if this MR is still running.
    last_poll_time: last time the controller job has polled this mapreduce.
    counters_map: shard's counters map as CountersMap. Mirrors
      counters_map_json.
    chart_url: last computed mapreduce status chart url. This chart displays
      the progress of all the shards the best way it can.
    sparkline_url: last computed mapreduce status chart url in small format.
    result_status: If not None, the final status of the job.
    active_shards: How many shards are still processing. This starts as 0,
      is then set by the KickOffJob handler to the actual number of input
      readers after input splitting, and is updated by the Controller task
      as shards finish.
    start_time: When the job started.
    writer_state: Json property to be used by the writer to store its state.
      This is filled only when there is a single output per job. Will be
      deprecated; use OutputWriter.get_filenames instead.
  """

  RESULT_SUCCESS = "success"
  RESULT_FAILED = "failed"
  RESULT_ABORTED = "aborted"

  _RESULTS = frozenset([RESULT_SUCCESS, RESULT_FAILED, RESULT_ABORTED])

  # Functional properties.
  # TODO(user): Replace mapreduce_spec with job_config.
  mapreduce_spec = json_util.JsonProperty(MapreduceSpec, indexed=False)
  active = db.BooleanProperty(default=True, indexed=False)
  last_poll_time = db.DateTimeProperty(required=True)
  counters_map = json_util.JsonProperty(
      CountersMap, default=CountersMap(), indexed=False)
  app_id = db.StringProperty(required=False, indexed=True)
  writer_state = json_util.JsonProperty(dict, indexed=False)
  active_shards = db.IntegerProperty(default=0, indexed=False)
  failed_shards = db.IntegerProperty(default=0, indexed=False)
  aborted_shards = db.IntegerProperty(default=0, indexed=False)
  result_status = db.StringProperty(required=False, choices=_RESULTS)

  # For UI purposes only.
  chart_url = db.TextProperty(default="")
  chart_width = db.IntegerProperty(default=300, indexed=False)
  sparkline_url = db.TextProperty(default="")
  start_time = db.DateTimeProperty(auto_now_add=True)

  @classmethod
  def kind(cls):
    """Returns entity kind."""
    return "_AE_MR_MapreduceState"

  @classmethod
  def get_key_by_job_id(cls, mapreduce_id):
    """Retrieves the Key for a Job.

    Args:
      mapreduce_id: The job to retrieve.

    Returns:
      Datastore Key that can be used to fetch the MapreduceState.
    """
    return db.Key.from_path(cls.kind(), str(mapreduce_id))

  @classmethod
  def get_by_job_id(cls, mapreduce_id):
    """Retrieves the instance of state for a Job.

    Args:
      mapreduce_id: The mapreduce job to retrieve.

    Returns:
      instance of MapreduceState for passed id.
    """
    return db.get(cls.get_key_by_job_id(mapreduce_id))

  def set_processed_counts(self, shards_processed, shards_status):
    """Updates a chart url to display processed count for each shard.

    Args:
      shards_processed: list of integers with number of processed entities in
        each shard.
      shards_status: list of shard status strings, parallel to
        shards_processed.
    """
    chart = google_chart_api.BarChart()

    def filter_status(status_to_filter):
      return [count if status == status_to_filter else 0
              for count, status in zip(shards_processed, shards_status)]

    if shards_status:
      # Each index will have only one non-zero count, so stack them to
      # color-code the bars by status.
      # These status values are computed in _update_state_from_shard_states,
      # in mapreduce/handlers.py.
      chart.stacked = True
      chart.AddBars(filter_status("unknown"), color="404040")
      chart.AddBars(filter_status("success"), color="00ac42")
      chart.AddBars(filter_status("running"), color="3636a9")
      chart.AddBars(filter_status("aborted"), color="e29e24")
      chart.AddBars(filter_status("failed"), color="f6350f")
    else:
      chart.AddBars(shards_processed)

    shard_count = len(shards_processed)

    if shard_count > 95:
      # Auto-spacing does not work for large numbers of shards.
      pixels_per_shard = 700.0 / shard_count
      bar_thickness = int(pixels_per_shard * .9)

      chart.style = bar_chart.BarChartStyle(bar_thickness=bar_thickness,
                                            bar_gap=0.1,
                                            use_fractional_gap_spacing=True)

    if shards_processed and shard_count <= 95:
      # Adding labels puts us in danger of exceeding the URL length; only
      # do it when we have a small amount of data to plot.
      # Only 16 labels on the whole chart.
      stride_length = max(1, shard_count / 16)
      chart.bottom.labels = []
      for x in xrange(shard_count):
        if (x % stride_length == 0 or
            x == shard_count - 1):
          chart.bottom.labels.append(x)
        else:
          chart.bottom.labels.append("")
      chart.left.labels = ["0", str(max(shards_processed))]
      chart.left.min = 0

    self.chart_width = min(700, max(300, shard_count * 20))
    self.chart_url = chart.display.Url(self.chart_width, 200)

  def get_processed(self):
    """Number of processed entities.

    Returns:
      The total number of processed entities as int.
    """
    return self.counters_map.get(context.COUNTER_MAPPER_CALLS)

  processed = property(get_processed)

  @staticmethod
  def create_new(mapreduce_id=None,
                 gettime=datetime.datetime.now):
    """Create a new MapreduceState.

    Args:
      mapreduce_id: Mapreduce id as string.
      gettime: Used for testing.
    """
    if not mapreduce_id:
      mapreduce_id = MapreduceState.new_mapreduce_id()
    state = MapreduceState(key_name=mapreduce_id,
                           last_poll_time=gettime())
    state.set_processed_counts([], [])
    return state

  @staticmethod
  def new_mapreduce_id():
    """Generate new mapreduce id."""
    return util._get_descending_key()

  def __eq__(self, other):
    if not isinstance(other, self.__class__):
      return False
    # Compare property values; properties() alone compares only the
    # class-level property definitions, which are always equal.
    return all(getattr(self, name) == getattr(other, name)
               for name in self.properties())


class TransientShardState(object):
  """A shard's states that are kept in task payload.

  TransientShardState holds two types of states:
  1. Some states just don't need to be saved to datastore, e.g.
     serialized input reader and output writer instances.
  2. Some states are duplicated from datastore, e.g. slice_id, shard_id.
     These are used to validate the task.
  """

  def __init__(self,
               base_path,
               mapreduce_spec,
               shard_id,
               slice_id,
               input_reader,
               initial_input_reader,
               output_writer=None,
               retries=0,
               handler=None):
    """Init.

    Args:
      base_path: base path of this mapreduce job. Deprecated.
      mapreduce_spec: an instance of MapreduceSpec.
      shard_id: shard id.
      slice_id: slice id. When enqueuing task for the next slice, this number
        is incremented by 1.
      input_reader: input reader instance for this shard.
      initial_input_reader: the input reader instance before any iteration.
        Used by shard retry.
      output_writer: output writer instance for this shard, if any.
      retries: the number of retries of the current shard. Used to drop
        tasks from old retries.
      handler: map/reduce handler.
    """
    self.base_path = base_path
    self.mapreduce_spec = mapreduce_spec
    self.shard_id = shard_id
    self.slice_id = slice_id
    self.input_reader = input_reader
    self.initial_input_reader = initial_input_reader
    self.output_writer = output_writer
    self.retries = retries
    self.handler = handler
    self._input_reader_json = self.input_reader.to_json()

  def reset_for_retry(self, output_writer):
    """Reset self for shard retry.

    Args:
      output_writer: new output writer that contains new output files.
    """
    self.input_reader = self.initial_input_reader
    self.slice_id = 0
    self.retries += 1
    self.output_writer = output_writer
    self.handler = self.mapreduce_spec.mapper.handler

  def advance_for_next_slice(self, recovery_slice=False):
    """Advance relevant states for next slice.

    Args:
      recovery_slice: True if this slice is running recovery logic.
        See handlers.MapperWorkerCallbackHandler._attempt_slice_recovery
        for more info.
    """
    if recovery_slice:
      self.slice_id += 2
      # Restore input reader to the beginning of the slice.
      self.input_reader = self.input_reader.from_json(self._input_reader_json)
    else:
      self.slice_id += 1

  def to_dict(self):
    """Convert state to dictionary to save in task payload."""
    result = {"mapreduce_spec": self.mapreduce_spec.to_json_str(),
              "shard_id": self.shard_id,
              "slice_id": str(self.slice_id),
              "input_reader_state": self.input_reader.to_json_str(),
              "initial_input_reader_state":
                  self.initial_input_reader.to_json_str(),
              "retries": str(self.retries)}
    if self.output_writer:
      result["output_writer_state"] = self.output_writer.to_json_str()
    serialized_handler = util.try_serialize_handler(self.handler)
    if serialized_handler:
      result["serialized_handler"] = serialized_handler
    return result

  @classmethod
  def from_request(cls, request):
    """Create new TransientShardState from webapp request."""
    mapreduce_spec = MapreduceSpec.from_json_str(request.get("mapreduce_spec"))
    mapper_spec = mapreduce_spec.mapper
    input_reader_spec_dict = json.loads(request.get("input_reader_state"),
                                        cls=json_util.JsonDecoder)
    input_reader = mapper_spec.input_reader_class().from_json(
        input_reader_spec_dict)
    initial_input_reader_spec_dict = json.loads(
        request.get("initial_input_reader_state"), cls=json_util.JsonDecoder)
    initial_input_reader = mapper_spec.input_reader_class().from_json(
        initial_input_reader_spec_dict)

    output_writer = None
    if mapper_spec.output_writer_class():
      output_writer = mapper_spec.output_writer_class().from_json(
          json.loads(request.get("output_writer_state", "{}"),
                     cls=json_util.JsonDecoder))
      assert isinstance(output_writer, mapper_spec.output_writer_class()), (
          "%s.from_json returned an instance of wrong class: %s" % (
              mapper_spec.output_writer_class(),
              output_writer.__class__))

    handler = util.try_deserialize_handler(request.get("serialized_handler"))
    if not handler:
      handler = mapreduce_spec.mapper.handler

    return cls(mapreduce_spec.params["base_path"],
               mapreduce_spec,
               str(request.get("shard_id")),
               int(request.get("slice_id")),
               input_reader,
               initial_input_reader,
               output_writer=output_writer,
               retries=int(request.get("retries")),
               handler=handler)


class ShardState(db.Model):
  """Single shard execution state.

  The shard state is stored in the datastore and is later aggregated by
  the controller task. ShardState key_name is equal to shard_id.

  Shard state contains critical state used to ensure the correctness of
  shard execution. It is the single source of truth about a shard's
  progress. For example:
  1. A slice is allowed to run only if its payload matches shard state's
     expectation.
  2. A slice is considered running only if it has acquired the shard's lock.
  3. A slice is considered done only if it has successfully committed shard
     state to db.

  Properties about the shard:
    active: if we have this shard still running as boolean.
    counters_map: shard's counters map as CountersMap. All counters yielded
      within mapreduce are stored here.
    mapreduce_id: unique id of the mapreduce.
    shard_id: unique id of this shard as string.
    shard_number: ordered number for this shard.
    retries: the number of times this shard has been retried.
    result_status: If not None, the final status of this shard.
    update_time: The last time this shard state was updated.
    shard_description: A string description of the work this shard will do.
    last_work_item: A string description of the last work item processed.
    writer_state: writer state for this shard. The shard's output writer
      instance can save in-memory output references to this field in its
      "finalize" method.

  Properties about slice management:
    slice_id: slice id of the currently executing slice. A slice's task
      will not run unless its slice_id matches this. Initial
      value is 0. By the end of slice execution, this number is
      incremented by 1.
    slice_start_time: a slice updates this to now at the beginning of
      execution. If the transaction succeeds, the current task holds
      a lease of slice duration + some grace period. During this time, no
      other task with the same slice_id will execute. Upon slice failure,
      the task should try to unset this value to allow retries to carry on
      ASAP.
    slice_request_id: the request id that holds/held the lease. When the
      lease has expired, a new request needs to verify that said request has
      indeed ended according to the logs API. Do this only when the lease has
      expired because the logs API is expensive. This field should always be
      set/unset together with slice_start_time. It is possible that the logs
      API doesn't log a request at all or doesn't log the end of a request,
      so a new request can proceed after a long conservative timeout.
    slice_retries: the number of times a slice has been retried due to
      failures while processing data with the lock held. Taskqueue/datastore
      errors related to slice/shard management are not counted. This count
      is only a lower bound and is used to determine when to fail a slice
      completely.
    acquired_once: whether the lock for this slice has been acquired at
      least once. When this is True, duplicates in outputs are possible.
  """

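  # Sketch of a slice's interaction with this state (the real logic lives
  # in mapreduce/handlers.py; names below are illustrative):
  #
  #   state = ShardState.get_by_shard_id(shard_id)
  #   if task_slice_id != state.slice_id:
  #     return                          # stale or duplicate task, drop it
  #   ...acquire the lease, process input, write output...
  #   state.advance_for_next_slice()    # slice_id += 1, lease fields cleared
  #   state.put()
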
  RESULT_SUCCESS = "success"
  RESULT_FAILED = "failed"
  # Shard can be in aborted state when user issued abort, or controller
  # issued abort because some other shard failed.
  RESULT_ABORTED = "aborted"

  _RESULTS = frozenset([RESULT_SUCCESS, RESULT_FAILED, RESULT_ABORTED])

  # Maximum number of shard states to hold in memory at any time.
  _MAX_STATES_IN_MEMORY = 10

  # Functional properties.
  mapreduce_id = db.StringProperty(required=True)
  active = db.BooleanProperty(default=True, indexed=False)
  input_finished = db.BooleanProperty(default=False, indexed=False)
  counters_map = json_util.JsonProperty(
      CountersMap, default=CountersMap(), indexed=False)
  result_status = db.StringProperty(choices=_RESULTS, indexed=False)
  retries = db.IntegerProperty(default=0, indexed=False)
  writer_state = json_util.JsonProperty(dict, indexed=False)
  slice_id = db.IntegerProperty(default=0, indexed=False)
  slice_start_time = db.DateTimeProperty(indexed=False)
  slice_request_id = db.ByteStringProperty(indexed=False)
  slice_retries = db.IntegerProperty(default=0, indexed=False)
  acquired_once = db.BooleanProperty(default=False, indexed=False)

  # For UI purposes only.
  update_time = db.DateTimeProperty(auto_now=True, indexed=False)
  shard_description = db.TextProperty(default="")
  last_work_item = db.TextProperty(default="")

  def __str__(self):
    kv = {"active": self.active,
          "slice_id": self.slice_id,
          "last_work_item": self.last_work_item,
          "update_time": self.update_time}
    if self.result_status:
      kv["result_status"] = self.result_status
    if self.retries:
      kv["retries"] = self.retries
    if self.slice_start_time:
      kv["slice_start_time"] = self.slice_start_time
    if self.slice_retries:
      kv["slice_retries"] = self.slice_retries
    if self.slice_request_id:
      kv["slice_request_id"] = self.slice_request_id
    if self.acquired_once:
      kv["acquired_once"] = self.acquired_once

    result = "ShardState is {"
    for k in sorted(kv):
      result += k + ":" + str(kv[k]) + ","
    result += "}"
    return result

  def reset_for_retry(self):
    """Reset self for shard retry."""
    self.retries += 1
    self.last_work_item = ""
    self.active = True
    self.result_status = None
    self.input_finished = False
    self.counters_map = CountersMap()
    self.slice_id = 0
    self.slice_start_time = None
    self.slice_request_id = None
    self.slice_retries = 0
    self.acquired_once = False

  def advance_for_next_slice(self, recovery_slice=False):
    """Advance self for next slice.

    Args:
      recovery_slice: True if this slice is running recovery logic.
        See handlers.MapperWorkerCallbackHandler._attempt_slice_recovery
        for more info.
    """
    self.slice_start_time = None
    self.slice_request_id = None
    self.slice_retries = 0
    self.acquired_once = False
    if recovery_slice:
      self.slice_id += 2
    else:
      self.slice_id += 1

  def set_for_failure(self):
    self.active = False
    self.result_status = self.RESULT_FAILED

  def set_for_abort(self):
    self.active = False
    self.result_status = self.RESULT_ABORTED

  def set_input_finished(self):
    self.input_finished = True

  def is_input_finished(self):
    return self.input_finished

  def set_for_success(self):
    self.active = False
    self.result_status = self.RESULT_SUCCESS
    self.slice_start_time = None
    self.slice_request_id = None
    self.slice_retries = 0
    self.acquired_once = False

  def copy_from(self, other_state):
    """Copy data from another shard state entity to self."""
    for prop in self.properties().values():
      setattr(self, prop.name, getattr(other_state, prop.name))

  def __eq__(self, other):
    if not isinstance(other, self.__class__):
      return False
    # Compare property values; properties() alone compares only the
    # class-level property definitions, which are always equal.
    return all(getattr(self, name) == getattr(other, name)
               for name in self.properties())

  def get_shard_number(self):
    """Gets the shard number from the key name."""
    return int(self.key().name().split("-")[-1])

  shard_number = property(get_shard_number)

  def get_shard_id(self):
    """Returns the shard ID."""
    return self.key().name()

  shard_id = property(get_shard_id)

  @classmethod
  def kind(cls):
    """Returns entity kind."""
    return "_AE_MR_ShardState"

  @classmethod
  def shard_id_from_number(cls, mapreduce_id, shard_number):
    """Get shard id by mapreduce id and shard number.

    Args:
      mapreduce_id: mapreduce id as string.
      shard_number: shard number to compute id for as int.

    Returns:
      shard id as string.
    """
    return "%s-%d" % (mapreduce_id, shard_number)

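  # For example, shard_id_from_number("15827394761028", 3) returns
  # "15827394761028-3" (mapreduce id illustrative), which is also the
  # ShardState key name that get_shard_number() parses back into 3.
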
  @classmethod
  def get_key_by_shard_id(cls, shard_id):
    """Retrieves the Key for this ShardState.

    Args:
      shard_id: The shard ID to fetch.

    Returns:
      The Datastore key to use to retrieve this ShardState.
    """
    return db.Key.from_path(cls.kind(), shard_id)

  @classmethod
  def get_by_shard_id(cls, shard_id):
    """Get shard state from datastore by shard_id.

    Args:
      shard_id: shard id as string.

    Returns:
      ShardState for given shard id or None if it's not found.
    """
    return cls.get_by_key_name(shard_id)

  @classmethod
  def find_by_mapreduce_state(cls, mapreduce_state):
    """Find all shard states for given mapreduce.

    Deprecated. Use find_all_by_mapreduce_state.
    This will be removed after 1.8.9 release.

    Args:
      mapreduce_state: MapreduceState instance

    Returns:
      A list of ShardStates.
    """
    return list(cls.find_all_by_mapreduce_state(mapreduce_state))

  @classmethod
  def find_all_by_mapreduce_state(cls, mapreduce_state):
    """Find all shard states for given mapreduce.

    Args:
      mapreduce_state: MapreduceState instance

    Yields:
      shard states sorted by shard id.
    """
    keys = cls.calculate_keys_by_mapreduce_state(mapreduce_state)
    i = 0
    while i < len(keys):
      @db.non_transactional
      def no_tx_get(i):
        return db.get(keys[i:i+cls._MAX_STATES_IN_MEMORY])
      # We need a separate function so that we can mix the non-transactional
      # decorator with generator code.
      states = no_tx_get(i)
      for s in states:
        i += 1
        if s is not None:
          yield s

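  # This fetches states in batches of _MAX_STATES_IN_MEMORY to bound memory
  # use. Typical iteration (sketch; mr_id illustrative):
  #
  #   state = MapreduceState.get_by_job_id(mr_id)
  #   for shard_state in ShardState.find_all_by_mapreduce_state(state):
  #     logging.info("%s active=%s", shard_state.shard_id, shard_state.active)
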
  @classmethod
  def calculate_keys_by_mapreduce_state(cls, mapreduce_state):
    """Calculate all shard states keys for given mapreduce.

    Args:
      mapreduce_state: MapreduceState instance

    Returns:
      A list of keys for shard states, sorted by shard id.
      The corresponding shard states may not exist.
    """
    if mapreduce_state is None:
      return []

    keys = []
    for i in range(mapreduce_state.mapreduce_spec.mapper.shard_count):
      shard_id = cls.shard_id_from_number(mapreduce_state.key().name(), i)
      keys.append(cls.get_key_by_shard_id(shard_id))
    return keys

  @classmethod
  def create_new(cls, mapreduce_id, shard_number):
    """Create new shard state.

    Args:
      mapreduce_id: unique mapreduce id as string.
      shard_number: shard number for which to create shard state.

    Returns:
      new instance of ShardState ready to put into datastore.
    """
    shard_id = cls.shard_id_from_number(mapreduce_id, shard_number)
    state = cls(key_name=shard_id,
                mapreduce_id=mapreduce_id)
    return state


class MapreduceControl(db.Model):
  """Datastore entity used to control mapreduce job execution.

  Only one command may be sent to jobs at a time.

  Properties:
    command: The command to send to the job.
  """

  ABORT = "abort"

  _COMMANDS = frozenset([ABORT])
  _KEY_NAME = "command"

  command = db.TextProperty(choices=_COMMANDS, required=True)

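  # Aborting a job is a single datastore write, which controller and worker
  # tasks poll for (mr_id illustrative):
  #
  #   MapreduceControl.abort(mr_id)
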
  @classmethod
  def kind(cls):
    """Returns entity kind."""
    return "_AE_MR_MapreduceControl"

  @classmethod
  def get_key_by_job_id(cls, mapreduce_id):
    """Retrieves the Key for a mapreduce ID.

    Args:
      mapreduce_id: The job to fetch.

    Returns:
      Datastore Key for the command for the given job ID.
    """
    return db.Key.from_path(cls.kind(), "%s:%s" % (mapreduce_id, cls._KEY_NAME))

  @classmethod
  def abort(cls, mapreduce_id, **kwargs):
    """Causes a job to abort.

    Args:
      mapreduce_id: The job to abort. Not verified as a valid job.
      **kwargs: extra keyword arguments passed through to db.Model.put.
    """
    cls(key_name="%s:%s" % (mapreduce_id, cls._KEY_NAME),
        command=cls.ABORT).put(**kwargs)


class QuerySpec(object):
  """Encapsulates everything about a query needed by DatastoreInputReader."""

  DEFAULT_BATCH_SIZE = 50
  DEFAULT_OVERSPLIT_FACTOR = 1

  def __init__(self,
               entity_kind,
               keys_only=None,
               filters=None,
               batch_size=None,
               oversplit_factor=None,
               model_class_path=None,
               app=None,
               ns=None):
    self.entity_kind = entity_kind
    self.keys_only = keys_only or False
    self.filters = filters or None
    self.batch_size = batch_size or self.DEFAULT_BATCH_SIZE
    self.oversplit_factor = (oversplit_factor or
                             self.DEFAULT_OVERSPLIT_FACTOR)
    self.model_class_path = model_class_path
    self.app = app
    self.ns = ns

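  # A construction sketch (kind, filters and namespace illustrative):
  #
  #   spec = QuerySpec("MyKind",
  #                    filters=[("status", "=", "new")],
  #                    batch_size=100,
  #                    ns="tenant1")
  #   spec = QuerySpec.from_json(spec.to_json())  # round-trips
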
  def to_json(self):
    return {"entity_kind": self.entity_kind,
            "keys_only": self.keys_only,
            "filters": self.filters,
            "batch_size": self.batch_size,
            "oversplit_factor": self.oversplit_factor,
            "model_class_path": self.model_class_path,
            "app": self.app,
            "ns": self.ns}

  @classmethod
  def from_json(cls, json):
    return cls(json["entity_kind"],
               json["keys_only"],
               json["filters"],
               json["batch_size"],
               json["oversplit_factor"],
               json["model_class_path"],
               json["app"],
               json["ns"])
1255