1#!/usr/bin/env python
2"""Model Datastore Input Reader implementation for the map_job API."""
3import copy
4
5from google.appengine.ext import ndb
6
7from google.appengine.ext import db
8from mapreduce import datastore_range_iterators as db_iters
9from mapreduce import errors
10from mapreduce import namespace_range
11from mapreduce import property_range
12from mapreduce import util
13from mapreduce.api.map_job import abstract_datastore_input_reader
14
15# pylint: disable=invalid-name
16
17
class ModelDatastoreInputReader(abstract_datastore_input_reader
                                .AbstractDatastoreInputReader):
  """Implementation of an input reader for Datastore.

  Iterates over a Model and yields model instances.
  Supports both db.Model and ndb.Model.
  """

  _KEY_RANGE_ITER_CLS = db_iters.KeyRangeModelIterator

  @classmethod
  def _get_raw_entity_kind(cls, model_classpath):
    """Returns the datastore kind for the model at model_classpath.

    Args:
      model_classpath: fully qualified path of the model, resolved with
        util.for_name.

    Returns:
      The kind reported by the model itself for db and ndb models, otherwise
      the short name derived from model_classpath.
    """
    entity_type = util.for_name(model_classpath)
    # for_name is expected to yield the model *class* (validate() calls
    # issubclass on its result), and a class is never an instance of
    # db.Model. Check issubclass as well so a db model's own kind() is
    # honoured instead of silently falling back to the short name.
    is_db_model = isinstance(entity_type, db.Model) or (
        isinstance(entity_type, type) and issubclass(entity_type, db.Model))
    if is_db_model:
      return entity_type.kind()
    if isinstance(entity_type, (ndb.Model, ndb.MetaModel)):
      # pylint: disable=protected-access
      return entity_type._get_kind()
    return util.get_short_name(model_classpath)

  @classmethod
  def split_input(cls, job_config):
    """Inherit docs.

    When the query filters allow sharding by property range, one reader is
    created per (property range, namespace range) pair; otherwise splitting
    is delegated to the parent class.

    Args:
      job_config: job configuration; its input_reader_params and shard_count
        are read here.

    Returns:
      A list of readers, or None when no namespace was requested and no
      namespace keys were found.
    """
    params = job_config.input_reader_params
    shard_count = job_config.shard_count
    query_spec = cls._get_query_spec(params)

    if not property_range.should_shard_by_property_range(query_spec.filters):
      return super(ModelDatastoreInputReader, cls).split_input(job_config)

    p_range = property_range.PropertyRange(query_spec.filters,
                                           query_spec.model_class_path)
    p_ranges = p_range.split(shard_count)

    # User specified a namespace.
    if query_spec.ns:
      ns_range = namespace_range.NamespaceRange(
          namespace_start=query_spec.ns,
          namespace_end=query_spec.ns,
          _app=query_spec.app)
      ns_ranges = [copy.copy(ns_range) for _ in p_ranges]
    else:
      # Fetch one key more than the limit so "too many namespaces" can be
      # told apart from "exactly at the limit".
      ns_keys = namespace_range.get_namespace_keys(
          query_spec.app, cls.MAX_NAMESPACES_FOR_KEY_SHARD+1)
      if not ns_keys:
        # Nothing to read. None (not an empty list) is kept on purpose so
        # existing callers observe the same value as before.
        return None
      # User doesn't specify ns but the number of ns is small.
      # We still split by property range.
      if len(ns_keys) <= cls.MAX_NAMESPACES_FOR_KEY_SHARD:
        ns_ranges = [namespace_range.NamespaceRange(_app=query_spec.app)
                     for _ in p_ranges]
      # Lots of namespaces. Split by ns.
      else:
        ns_ranges = namespace_range.NamespaceRange.split(n=shard_count,
                                                         contiguous=False,
                                                         can_query=lambda: True,
                                                         _app=query_spec.app)
        # Every namespace shard scans the full, unsplit property range.
        p_ranges = [copy.copy(p_range) for _ in ns_ranges]

    # Ranges are paired positionally below, so the lists must line up.
    assert len(p_ranges) == len(ns_ranges)

    iters = [
        db_iters.RangeIteratorFactory.create_property_range_iterator(
            p, ns, query_spec) for p, ns in zip(p_ranges, ns_ranges)]
    return [cls(i) for i in iters]

  @classmethod
  def validate(cls, job_config):
    """Inherit docs.

    In addition to the parent's checks, verifies that the entity kind can be
    imported and that any user supplied filters are valid for that model.

    Args:
      job_config: job configuration whose input_reader_params are validated.

    Raises:
      BadReaderParamsError: if the model cannot be imported or a filter is
        invalid.
    """
    super(ModelDatastoreInputReader, cls).validate(job_config)
    params = job_config.input_reader_params
    entity_kind = params[cls.ENTITY_KIND_PARAM]
    # Fail fast if Model cannot be located.
    try:
      model_class = util.for_name(entity_kind)
    except ImportError as e:
      raise errors.BadReaderParamsError("Bad entity kind: %s" % e)
    if cls.FILTERS_PARAM in params:
      filters = params[cls.FILTERS_PARAM]
      if issubclass(model_class, db.Model):
        cls._validate_filters(filters, model_class)
      else:
        cls._validate_filters_ndb(filters, model_class)
      # The instance is discarded; it is built only for whatever checks the
      # PropertyRange constructor performs on the filters.
      # NOTE(review): presumably it raises on unsupported filters - confirm.
      property_range.PropertyRange(filters, entity_kind)

  @classmethod
  def _validate_filters(cls, filters, model_class):
    """Validate user supplied filters.

    Validate filters are on existing properties and filter values
    have valid semantics.

    Args:
      filters: user supplied filters. Each filter should be a list or tuple of
        format (<property_name_as_str>, <query_operator_as_str>,
        <value_of_certain_type>). Value type is up to the property's type.
      model_class: the db.Model class for the entity type to apply filters on.

    Raises:
      BadReaderParamsError: if any filter is invalid in any way.
    """
    if not filters:
      return

    properties = model_class.properties()

    for f in filters:
      prop, _, val = f
      if prop not in properties:
        # Interpolate explicitly: exception constructors do not apply
        # %-formatting to their arguments.
        raise errors.BadReaderParamsError(
            "Property %s is not defined for entity type %s" %
            (prop, model_class.kind()))

      # Validate the value of each filter. We need to know filters have
      # valid value to carry out splits.
      try:
        properties[prop].validate(val)
      except db.BadValueError as e:
        raise errors.BadReaderParamsError(e)

  @classmethod
  # pylint: disable=protected-access
  def _validate_filters_ndb(cls, filters, model_class):
    """Validate ndb.Model filters.

    Validate filters are on existing properties and filter values
    have valid semantics.

    Args:
      filters: user supplied filters. Each filter should be a list or tuple of
        format (<property_name_as_str>, <query_operator_as_str>,
        <value_of_certain_type>).
      model_class: the model class whose _properties are checked; used by
        validate() for models that are not db.Model subclasses.

    Raises:
      BadReaderParamsError: if any filter is invalid in any way.
    """
    if not filters:
      return

    properties = model_class._properties

    for f in filters:
      prop, _, val = f
      if prop not in properties:
        # Interpolate explicitly: exception constructors do not apply
        # %-formatting to their arguments.
        raise errors.BadReaderParamsError(
            "Property %s is not defined for entity type %s" %
            (prop, model_class._get_kind()))

      # Validate the value of each filter. We need to know filters have
      # valid value to carry out splits.
      try:
        properties[prop]._do_validate(val)
      except db.BadValueError as e:
        raise errors.BadReaderParamsError(e)
161
162