1"""
2Extensions to Django's model logic.
3"""
4
5import django.core.exceptions
6from django.db import connection
7from django.db import connections
8from django.db import models as dbmodels
9from django.db import transaction
10from django.db.models.sql import query
11import django.db.models.sql.where
12# TODO(akeshet): Replace with monarch stats once we know how to instrument rpc
13# handling with ts_mon.
14from autotest_lib.client.common_lib.cros.graphite import autotest_stats
15from autotest_lib.frontend.afe import rdb_model_extensions
16
17
class ValidationError(django.core.exceptions.ValidationError):
    """\
    Data validation error raised when adding or updating an object.

    The associated value is a dictionary mapping field names to error
    strings.
    """
23
def _quote_name(name):
    """Return name quoted via connection.ops.quote_name()."""
    quote = connection.ops.quote_name
    return quote(name)
27
28
class LeasedHostManager(dbmodels.Manager):
    """Query manager restricted to hosts with leased=0 and locked=0.
    """
    def get_query_set(self):
        """Return the parent query set filtered on leased=0, locked=0."""
        base_query_set = super(LeasedHostManager, self).get_query_set()
        return base_query_set.filter(leased=0, locked=0)
35
36
37class ExtendedManager(dbmodels.Manager):
38    """\
39    Extended manager supporting subquery filtering.
40    """
41
42    class CustomQuery(query.Query):
43        def __init__(self, *args, **kwargs):
44            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
45            self._custom_joins = []
46
47
48        def clone(self, klass=None, **kwargs):
49            obj = super(ExtendedManager.CustomQuery, self).clone(klass)
50            obj._custom_joins = list(self._custom_joins)
51            return obj
52
53
54        def combine(self, rhs, connector):
55            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
56            if hasattr(rhs, '_custom_joins'):
57                self._custom_joins.extend(rhs._custom_joins)
58
59
60        def add_custom_join(self, table, condition, join_type,
61                            condition_values=(), alias=None):
62            if alias is None:
63                alias = table
64            join_dict = dict(table=table,
65                             condition=condition,
66                             condition_values=condition_values,
67                             join_type=join_type,
68                             alias=alias)
69            self._custom_joins.append(join_dict)
70
71
72        @classmethod
73        def convert_query(self, query_set):
74            """
75            Convert the query set's "query" attribute to a CustomQuery.
76            """
77            # Make a copy of the query set
78            query_set = query_set.all()
79            query_set.query = query_set.query.clone(
80                    klass=ExtendedManager.CustomQuery,
81                    _custom_joins=[])
82            return query_set
83
84
85    class _WhereClause(object):
86        """Object allowing us to inject arbitrary SQL into Django queries.
87
88        By using this instead of extra(where=...), we can still freely combine
89        queries with & and |.
90        """
91        def __init__(self, clause, values=()):
92            self._clause = clause
93            self._values = values
94
95
96        def as_sql(self, qn=None, connection=None):
97            return self._clause, self._values
98
99
100        def relabel_aliases(self, change_map):
101            return
102
103
104    def add_join(self, query_set, join_table, join_key, join_condition='',
105                 join_condition_values=(), join_from_key=None, alias=None,
106                 suffix='', exclude=False, force_left_join=False):
107        """Add a join to query_set.
108
109        Join looks like this:
110                (INNER|LEFT) JOIN <join_table> AS <alias>
111                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
112                        and <join_condition>)
113
114        @param join_table table to join to
115        @param join_key field referencing back to this model to use for the join
116        @param join_condition extra condition for the ON clause of the join
117        @param join_condition_values values to substitute into join_condition
118        @param join_from_key column on this model to join from.
119        @param alias alias to use for for join
120        @param suffix suffix to add to join_table for the join alias, if no
121                alias is provided
122        @param exclude if true, exclude rows that match this join (will use a
123        LEFT OUTER JOIN and an appropriate WHERE condition)
124        @param force_left_join - if true, a LEFT OUTER JOIN will be used
125        instead of an INNER JOIN regardless of other options
126        """
127        join_from_table = query_set.model._meta.db_table
128        if join_from_key is None:
129            join_from_key = self.model._meta.pk.name
130        if alias is None:
131            alias = join_table + suffix
132        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
133        full_join_condition = '%s = %s.%s' % (full_join_key,
134                                              _quote_name(join_from_table),
135                                              _quote_name(join_from_key))
136        if join_condition:
137            full_join_condition += ' AND (' + join_condition + ')'
138        if exclude or force_left_join:
139            join_type = query_set.query.LOUTER
140        else:
141            join_type = query_set.query.INNER
142
143        query_set = self.CustomQuery.convert_query(query_set)
144        query_set.query.add_custom_join(join_table,
145                                        full_join_condition,
146                                        join_type,
147                                        condition_values=join_condition_values,
148                                        alias=alias)
149
150        if exclude:
151            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])
152
153        return query_set
154
155
156    def _info_for_many_to_one_join(self, field, join_to_query, alias):
157        """
158        @param field: the ForeignKey field on the related model
159        @param join_to_query: the query over the related model that we're
160                joining to
161        @param alias: alias of joined table
162        """
163        info = {}
164        rhs_table = join_to_query.model._meta.db_table
165        info['rhs_table'] = rhs_table
166        info['rhs_column'] = field.column
167        info['lhs_column'] = field.rel.get_related_field().column
168        rhs_where = join_to_query.query.where
169        rhs_where.relabel_aliases({rhs_table: alias})
170        compiler = join_to_query.query.get_compiler(using=join_to_query.db)
171        initial_clause, values = compiler.as_sql()
172        # initial_clause is compiled from `join_to_query`, which is a SELECT
173        # query returns at most one record. For it to be used in WHERE clause,
174        # it must be converted to a boolean value using EXISTS.
175        all_clauses = ('EXISTS (%s)' % initial_clause,)
176        if hasattr(join_to_query.query, 'extra_where'):
177            all_clauses += join_to_query.query.extra_where
178        info['where_clause'] = (
179                    ' AND '.join('(%s)' % clause for clause in all_clauses))
180        info['values'] = values
181        return info
182
183
184    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
185                                    m2m_is_on_this_model):
186        """
187        @param m2m_field: a Django field representing the M2M relationship.
188                It uses a pivot table with the following structure:
189                this model table <---> M2M pivot table <---> joined model table
190        @param join_to_query: the query over the related model that we're
191                joining to.
192        @param alias: alias of joined table
193        """
194        if m2m_is_on_this_model:
195            # referenced field on this model
196            lhs_id_field = self.model._meta.pk
197            # foreign key on the pivot table referencing lhs_id_field
198            m2m_lhs_column = m2m_field.m2m_column_name()
199            # foreign key on the pivot table referencing rhd_id_field
200            m2m_rhs_column = m2m_field.m2m_reverse_name()
201            # referenced field on related model
202            rhs_id_field = m2m_field.rel.get_related_field()
203        else:
204            lhs_id_field = m2m_field.rel.get_related_field()
205            m2m_lhs_column = m2m_field.m2m_reverse_name()
206            m2m_rhs_column = m2m_field.m2m_column_name()
207            rhs_id_field = join_to_query.model._meta.pk
208
209        info = {}
210        info['rhs_table'] = m2m_field.m2m_db_table()
211        info['rhs_column'] = m2m_lhs_column
212        info['lhs_column'] = lhs_id_field.column
213
214        # select the ID of related models relevant to this join.  we can only do
215        # a single join, so we need to gather this information up front and
216        # include it in the join condition.
217        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
218        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
219                                   'match a single related object.')
220        rhs_id = rhs_ids[0]
221
222        info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
223                                               _quote_name(m2m_rhs_column),
224                                               rhs_id)
225        info['values'] = ()
226        return info
227
228
229    def join_custom_field(self, query_set, join_to_query, alias,
230                          left_join=True):
231        """Join to a related model to create a custom field in the given query.
232
233        This method is used to construct a custom field on the given query based
234        on a many-valued relationsip.  join_to_query should be a simple query
235        (no joins) on the related model which returns at most one related row
236        per instance of this model.
237
238        For many-to-one relationships, the joined table contains the matching
239        row from the related model it one is related, NULL otherwise.
240
241        For many-to-many relationships, the joined table contains the matching
242        row if it's related, NULL otherwise.
243        """
244        relationship_type, field = self.determine_relationship(
245                join_to_query.model)
246
247        if relationship_type == self.MANY_TO_ONE:
248            info = self._info_for_many_to_one_join(field, join_to_query, alias)
249        elif relationship_type == self.M2M_ON_RELATED_MODEL:
250            info = self._info_for_many_to_many_join(
251                    m2m_field=field, join_to_query=join_to_query, alias=alias,
252                    m2m_is_on_this_model=False)
253        elif relationship_type ==self.M2M_ON_THIS_MODEL:
254            info = self._info_for_many_to_many_join(
255                    m2m_field=field, join_to_query=join_to_query, alias=alias,
256                    m2m_is_on_this_model=True)
257
258        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
259                             join_from_key=info['lhs_column'],
260                             join_condition=info['where_clause'],
261                             join_condition_values=info['values'],
262                             alias=alias,
263                             force_left_join=left_join)
264
265
266    def add_where(self, query_set, where, values=()):
267        query_set = query_set.all()
268        query_set.query.where.add(self._WhereClause(where, values),
269                                  django.db.models.sql.where.AND)
270        return query_set
271
272
273    def _get_quoted_field(self, table, field):
274        return _quote_name(table) + '.' + _quote_name(field)
275
276
277    def get_key_on_this_table(self, key_field=None):
278        if key_field is None:
279            # default to primary key
280            key_field = self.model._meta.pk.column
281        return self._get_quoted_field(self.model._meta.db_table, key_field)
282
283
284    def escape_user_sql(self, sql):
285        return sql.replace('%', '%%')
286
287
288    def _custom_select_query(self, query_set, selects):
289        """Execute a custom select query.
290
291        @param query_set: query set as returned by query_objects.
292        @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id.
293
294        @returns: Result of the query as returned by cursor.fetchall().
295        """
296        compiler = query_set.query.get_compiler(using=query_set.db)
297        sql, params = compiler.as_sql()
298        from_ = sql[sql.find(' FROM'):]
299
300        if query_set.query.distinct:
301            distinct = 'DISTINCT '
302        else:
303            distinct = ''
304
305        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
306        # Chose the connection that's responsible for this type of object
307        cursor = connections[query_set.db].cursor()
308        cursor.execute(sql_query, params)
309        return cursor.fetchall()
310
311
312    def _is_relation_to(self, field, model_class):
313        return field.rel and field.rel.to is model_class
314
315
    # Unique sentinel objects naming the relationship types returned by
    # determine_relationship().
    MANY_TO_ONE = object()
    M2M_ON_RELATED_MODEL = object()
    M2M_ON_THIS_MODEL = object()
319
320    def determine_relationship(self, related_model):
321        """
322        Determine the relationship between this model and related_model.
323
324        related_model must have some sort of many-valued relationship to this
325        manager's model.
326        @returns (relationship_type, field), where relationship_type is one of
327                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
328                is the Django field object for the relationship.
329        """
330        # look for a foreign key field on related_model relating to this model
331        for field in related_model._meta.fields:
332            if self._is_relation_to(field, self.model):
333                return self.MANY_TO_ONE, field
334
335        # look for an M2M field on related_model relating to this model
336        for field in related_model._meta.many_to_many:
337            if self._is_relation_to(field, self.model):
338                return self.M2M_ON_RELATED_MODEL, field
339
340        # maybe this model has the many-to-many field
341        for field in self.model._meta.many_to_many:
342            if self._is_relation_to(field, related_model):
343                return self.M2M_ON_THIS_MODEL, field
344
345        raise ValueError('%s has no relation to %s' %
346                         (related_model, self.model))
347
348
349    def _get_pivot_iterator(self, base_objects_by_id, related_model):
350        """
351        Determine the relationship between this model and related_model, and
352        return a pivot iterator.
353        @param base_objects_by_id: dict of instances of this model indexed by
354        their IDs
355        @returns a pivot iterator, which yields a tuple (base_object,
356        related_object) for each relationship between a base object and a
357        related object.  all base_object instances come from base_objects_by_id.
358        Note -- this depends on Django model internals.
359        """
360        relationship_type, field = self.determine_relationship(related_model)
361        if relationship_type == self.MANY_TO_ONE:
362            return self._many_to_one_pivot(base_objects_by_id,
363                                           related_model, field)
364        elif relationship_type == self.M2M_ON_RELATED_MODEL:
365            return self._many_to_many_pivot(
366                    base_objects_by_id, related_model, field.m2m_db_table(),
367                    field.m2m_reverse_name(), field.m2m_column_name())
368        else:
369            assert relationship_type == self.M2M_ON_THIS_MODEL
370            return self._many_to_many_pivot(
371                    base_objects_by_id, related_model, field.m2m_db_table(),
372                    field.m2m_column_name(), field.m2m_reverse_name())
373
374
375    def _many_to_one_pivot(self, base_objects_by_id, related_model,
376                           foreign_key_field):
377        """
378        @returns a pivot iterator - see _get_pivot_iterator()
379        """
380        filter_data = {foreign_key_field.name + '__pk__in':
381                       base_objects_by_id.keys()}
382        for related_object in related_model.objects.filter(**filter_data):
383            # lookup base object in the dict, rather than grabbing it from the
384            # related object.  we need to return instances from the dict, not
385            # fresh instances of the same models (and grabbing model instances
386            # from the related models incurs a DB query each time).
387            base_object_id = getattr(related_object, foreign_key_field.attname)
388            base_object = base_objects_by_id[base_object_id]
389            yield base_object, related_object
390
391
392    def _query_pivot_table(self, base_objects_by_id, pivot_table,
393                           pivot_from_field, pivot_to_field, related_model):
394        """
395        @param id_list list of IDs of self.model objects to include
396        @param pivot_table the name of the pivot table
397        @param pivot_from_field a field name on pivot_table referencing
398        self.model
399        @param pivot_to_field a field name on pivot_table referencing the
400        related model.
401        @param related_model the related model
402
403        @returns pivot list of IDs (base_id, related_id)
404        """
405        query = """
406        SELECT %(from_field)s, %(to_field)s
407        FROM %(table)s
408        WHERE %(from_field)s IN (%(id_list)s)
409        """ % dict(from_field=pivot_from_field,
410                   to_field=pivot_to_field,
411                   table=pivot_table,
412                   id_list=','.join(str(id_) for id_
413                                    in base_objects_by_id.iterkeys()))
414
415        # Chose the connection that's responsible for this type of object
416        # The databases for related_model and the current model will always
417        # be the same, related_model is just easier to obtain here because
418        # self is only a ExtendedManager, not the object.
419        cursor = connections[related_model.objects.db].cursor()
420        cursor.execute(query)
421        return cursor.fetchall()
422
423
424    def _many_to_many_pivot(self, base_objects_by_id, related_model,
425                            pivot_table, pivot_from_field, pivot_to_field):
426        """
427        @param pivot_table: see _query_pivot_table
428        @param pivot_from_field: see _query_pivot_table
429        @param pivot_to_field: see _query_pivot_table
430        @returns a pivot iterator - see _get_pivot_iterator()
431        """
432        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
433                                           pivot_from_field, pivot_to_field,
434                                           related_model)
435
436        all_related_ids = list(set(related_id for base_id, related_id
437                                   in id_pivot))
438        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)
439
440        for base_id, related_id in id_pivot:
441            yield base_objects_by_id[base_id], related_objects_by_id[related_id]
442
443
444    def populate_relationships(self, base_objects, related_model,
445                               related_list_name):
446        """
447        For each instance of this model in base_objects, add a field named
448        related_list_name listing all the related objects of type related_model.
449        related_model must be in a many-to-one or many-to-many relationship with
450        this model.
451        @param base_objects - list of instances of this model
452        @param related_model - model class related to this model
453        @param related_list_name - attribute name in which to store the related
454        object list.
455        """
456        if not base_objects:
457            # if we don't bail early, we'll get a SQL error later
458            return
459
460        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
461                                  for base_object in base_objects)
462        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
463                                                  related_model)
464
465        for base_object in base_objects:
466            setattr(base_object, related_list_name, [])
467
468        for base_object, related_object in pivot_iterator:
469            getattr(base_object, related_list_name).append(related_object)
470
471
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet for models with an "invalid" bit; delete() is delegated to the
    delete() method of each object.
    """
    def delete(self):
        """Call delete() on every object in this query set, one by one."""
        for instance in iter(self):
            instance.delete()
479
480
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for objects with an "invalid" bit
    """
    def get_query_set(self):
        """Return a ModelWithInvalidQuerySet over this manager's model."""
        query_set_class = ModelWithInvalidQuerySet
        return query_set_class(self.model)
487
488
class ValidObjectsManager(ModelWithInvalidManager):
    """
    Manager returning only objects with invalid=False.
    """
    def get_query_set(self):
        """Return the parent query set filtered on invalid=False."""
        return (super(ValidObjectsManager, self).get_query_set()
                .filter(invalid=False))
496
497
498class ModelExtensions(rdb_model_extensions.ModelValidators):
499    """\
500    Mixin with convenience functions for models, built on top of
501    the model validators in rdb_model_extensions.
502    """
503    # TODO: at least some of these functions really belong in a custom
504    # Manager class
505
506
507    SERIALIZATION_LINKS_TO_FOLLOW = set()
508    """
509    To be able to send jobs and hosts to shards, it's necessary to find their
510    dependencies.
511    The most generic approach for this would be to traverse all relationships
512    to other objects recursively. This would list all objects that are related
513    in any way.
514    But this approach finds too many objects: If a host should be transferred,
    all its relationships would be traversed. This would find an acl group.
516    If then the acl group's relationships are traversed, the relationship
517    would be followed backwards and many other hosts would be found.
518
519    This mapping tells that algorithm which relations to follow explicitly.
520    """
521
522
523    SERIALIZATION_LINKS_TO_KEEP = set()
524    """This set stores foreign keys which we don't want to follow, but
525    still want to include in the serialized dictionary. For
526    example, we follow the relationship `Host.hostattribute_set`,
    but we do not want to follow `HostAttributes.host_id` back to
    Host, which would otherwise lead to a circle. However, we would still
    like to serialize HostAttribute.`host_id`."""
530
531    SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set()
532    """
    On deserialization, if the object to persist already exists, local fields
    will only be updated if their name is in this set.
535    """
536
537
538    @classmethod
539    def convert_human_readable_values(cls, data, to_human_readable=False):
540        """\
541        Performs conversions on user-supplied field data, to make it
542        easier for users to pass human-readable data.
543
544        For all fields that have choice sets, convert their values
545        from human-readable strings to enum values, if necessary.  This
546        allows users to pass strings instead of the corresponding
547        integer values.
548
549        For all foreign key fields, call smart_get with the supplied
550        data.  This allows the user to pass either an ID value or
551        the name of the object as a string.
552
553        If to_human_readable=True, perform the inverse - i.e. convert
554        numeric values to human readable values.
555
556        This method modifies data in-place.
557        """
558        field_dict = cls.get_field_dict()
559        for field_name in data:
560            if field_name not in field_dict or data[field_name] is None:
561                continue
562            field_obj = field_dict[field_name]
563            # convert enum values
564            if field_obj.choices:
565                for choice_data in field_obj.choices:
566                    # choice_data is (value, name)
567                    if to_human_readable:
568                        from_val, to_val = choice_data
569                    else:
570                        to_val, from_val = choice_data
571                    if from_val == data[field_name]:
572                        data[field_name] = to_val
573                        break
574            # convert foreign key values
575            elif field_obj.rel:
576                dest_obj = field_obj.rel.to.smart_get(data[field_name],
577                                                      valid_only=False)
578                if to_human_readable:
579                    # parameterized_jobs do not have a name_field
580                    if (field_name != 'parameterized_job' and
581                        dest_obj.name_field is not None):
582                        data[field_name] = getattr(dest_obj,
583                                                   dest_obj.name_field)
584                else:
585                    data[field_name] = dest_obj
586
587
588
589
590    def _validate_unique(self):
591        """\
592        Validate that unique fields are unique.  Django manipulators do
593        this too, but they're a huge pain to use manually.  Trust me.
594        """
595        errors = {}
596        cls = type(self)
597        field_dict = self.get_field_dict()
598        manager = cls.get_valid_manager()
599        for field_name, field_obj in field_dict.iteritems():
600            if not field_obj.unique:
601                continue
602
603            value = getattr(self, field_name)
604            if value is None and field_obj.auto_created:
605                # don't bother checking autoincrement fields about to be
606                # generated
607                continue
608
609            existing_objs = manager.filter(**{field_name : value})
610            num_existing = existing_objs.count()
611
612            if num_existing == 0:
613                continue
614            if num_existing == 1 and existing_objs[0].id == self.id:
615                continue
616            errors[field_name] = (
617                'This value must be unique (%s)' % (value))
618        return errors
619
620
621    def _validate(self):
622        """
623        First coerces all fields on this instance to their proper Python types.
624        Then runs validation on every field. Returns a dictionary of
625        field_name -> error_list.
626
627        Based on validate() from django.db.models.Model in Django 0.96, which
628        was removed in Django 1.0. It should reappear in a later version. See:
629            http://code.djangoproject.com/ticket/6845
630        """
631        error_dict = {}
632        for f in self._meta.fields:
633            try:
634                python_value = f.to_python(
635                    getattr(self, f.attname, f.get_default()))
636            except django.core.exceptions.ValidationError, e:
637                error_dict[f.name] = str(e)
638                continue
639
640            if not f.blank and not python_value:
641                error_dict[f.name] = 'This field is required.'
642                continue
643
644            setattr(self, f.attname, python_value)
645
646        return error_dict
647
648
649    def do_validate(self):
650        errors = self._validate()
651        unique_errors = self._validate_unique()
652        for field_name, error in unique_errors.iteritems():
653            errors.setdefault(field_name, error)
654        if errors:
655            raise ValidationError(errors)
656
657
658    # actually (externally) useful methods follow
659
660    @classmethod
661    def add_object(cls, data={}, **kwargs):
662        """\
663        Returns a new object created with the given data (a dictionary
664        mapping field names to values). Merges any extra keyword args
665        into data.
666        """
667        data = dict(data)
668        data.update(kwargs)
669        data = cls.prepare_data_args(data)
670        cls.convert_human_readable_values(data)
671        data = cls.provide_default_values(data)
672
673        obj = cls(**data)
674        obj.do_validate()
675        obj.save()
676        return obj
677
678
679    def update_object(self, data={}, **kwargs):
680        """\
681        Updates the object with the given data (a dictionary mapping
682        field names to values).  Merges any extra keyword args into
683        data.
684        """
685        data = dict(data)
686        data.update(kwargs)
687        data = self.prepare_data_args(data)
688        self.convert_human_readable_values(data)
689        for field_name, value in data.iteritems():
690            setattr(self, field_name, value)
691        self.do_validate()
692        self.save()
693
694
    # Keys of filter_data that are extracted by _extract_special_params()
    # instead of being passed to Django as regular filters; see
    # query_objects() for their meaning.
    _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
                            'extra_args', 'extra_where', 'no_distinct')
698
699
700    @classmethod
701    def _extract_special_params(cls, filter_data):
702        """
703        @returns a tuple of dicts (special_params, regular_filters), where
704        special_params contains the parameters we handle specially and
705        regular_filters is the remaining data to be handled by Django.
706        """
707        regular_filters = dict(filter_data)
708        special_params = {}
709        for key in cls._SPECIAL_FILTER_KEYS:
710            if key in regular_filters:
711                special_params[key] = regular_filters.pop(key)
712        return special_params, regular_filters
713
714
715    @classmethod
716    def apply_presentation(cls, query, filter_data):
717        """
718        Apply presentation parameters -- sorting and paging -- to the given
719        query.
720        @returns new query with presentation applied
721        """
722        special_params, _ = cls._extract_special_params(filter_data)
723        sort_by = special_params.get('sort_by', None)
724        if sort_by:
725            assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
726            query = query.extra(order_by=sort_by)
727
728        query_start = special_params.get('query_start', None)
729        query_limit = special_params.get('query_limit', None)
730        if query_start is not None:
731            if query_limit is None:
732                raise ValueError('Cannot pass query_start without query_limit')
733            # query_limit is passed as a page size
734            query_limit += query_start
735        return query[query_start:query_limit]
736
737
738    @classmethod
739    def query_objects(cls, filter_data, valid_only=True, initial_query=None,
740                      apply_presentation=True):
741        """\
742        Returns a QuerySet object for querying the given model_class
743        with the given filter_data.  Optional special arguments in
744        filter_data include:
745        -query_start: index of first return to return
746        -query_limit: maximum number of results to return
747        -sort_by: list of fields to sort on.  prefixing a '-' onto a
748         field name changes the sort to descending order.
749        -extra_args: keyword args to pass to query.extra() (see Django
750         DB layer documentation)
751        -extra_where: extra WHERE clause to append
752        -no_distinct: if True, a DISTINCT will not be added to the SELECT
753        """
754        special_params, regular_filters = cls._extract_special_params(
755                filter_data)
756
757        if initial_query is None:
758            if valid_only:
759                initial_query = cls.get_valid_manager()
760            else:
761                initial_query = cls.objects
762
763        query = initial_query.filter(**regular_filters)
764
765        use_distinct = not special_params.get('no_distinct', False)
766        if use_distinct:
767            query = query.distinct()
768
769        extra_args = special_params.get('extra_args', {})
770        extra_where = special_params.get('extra_where', None)
771        if extra_where:
772            # escape %'s
773            extra_where = cls.objects.escape_user_sql(extra_where)
774            extra_args.setdefault('where', []).append(extra_where)
775        if extra_args:
776            query = query.extra(**extra_args)
777            # TODO: Use readonly connection for these queries.
778            # This has been disabled, because it's not used anyway, as the
779            # configured readonly user is the same as the real user anyway.
780
781        if apply_presentation:
782            query = cls.apply_presentation(query, filter_data)
783
784        return query
785
786
787    @classmethod
788    def query_count(cls, filter_data, initial_query=None):
789        """\
790        Like query_objects, but retreive only the count of results.
791        """
792        filter_data.pop('query_start', None)
793        filter_data.pop('query_limit', None)
794        query = cls.query_objects(filter_data, initial_query=initial_query)
795        return query.count()
796
797
798    @classmethod
799    def clean_object_dicts(cls, field_dicts):
800        """\
801        Take a list of dicts corresponding to object (as returned by
802        query.values()) and clean the data to be more suitable for
803        returning to the user.
804        """
805        for field_dict in field_dicts:
806            cls.clean_foreign_keys(field_dict)
807            cls._convert_booleans(field_dict)
808            cls.convert_human_readable_values(field_dict,
809                                              to_human_readable=True)
810
811
812    @classmethod
813    def list_objects(cls, filter_data, initial_query=None):
814        """\
815        Like query_objects, but return a list of dictionaries.
816        """
817        query = cls.query_objects(filter_data, initial_query=initial_query)
818        extra_fields = query.query.extra_select.keys()
819        field_dicts = [model_object.get_object_dict(extra_fields=extra_fields)
820                       for model_object in query]
821        return field_dicts
822
823
824    @classmethod
825    def smart_get(cls, id_or_name, valid_only=True):
826        """\
827        smart_get(integer) -> get object by ID
828        smart_get(string) -> get object by name_field
829        """
830        if valid_only:
831            manager = cls.get_valid_manager()
832        else:
833            manager = cls.objects
834
835        if isinstance(id_or_name, (int, long)):
836            return manager.get(pk=id_or_name)
837        if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'):
838            return manager.get(**{cls.name_field : id_or_name})
839        raise ValueError(
840            'Invalid positional argument: %s (%s)' % (id_or_name,
841                                                      type(id_or_name)))
842
843
844    @classmethod
845    def smart_get_bulk(cls, id_or_name_list):
846        invalid_inputs = []
847        result_objects = []
848        for id_or_name in id_or_name_list:
849            try:
850                result_objects.append(cls.smart_get(id_or_name))
851            except cls.DoesNotExist:
852                invalid_inputs.append(id_or_name)
853        if invalid_inputs:
854            raise cls.DoesNotExist('The following %ss do not exist: %s'
855                                   % (cls.__name__.lower(),
856                                      ', '.join(invalid_inputs)))
857        return result_objects
858
859
860    def get_object_dict(self, extra_fields=None):
861        """\
862        Return a dictionary mapping fields to this object's values.  @param
863        extra_fields: list of extra attribute names to include, in addition to
864        the fields defined on this object.
865        """
866        fields = self.get_field_dict().keys()
867        if extra_fields:
868            fields += extra_fields
869        object_dict = dict((field_name, getattr(self, field_name))
870                           for field_name in fields)
871        self.clean_object_dicts([object_dict])
872        self._postprocess_object_dict(object_dict)
873        return object_dict
874
875
876    def _postprocess_object_dict(self, object_dict):
877        """For subclasses to override."""
878        pass
879
880
881    @classmethod
882    def get_valid_manager(cls):
883        return cls.objects
884
885
886    def _record_attributes(self, attributes):
887        """
888        See on_attribute_changed.
889        """
890        assert not isinstance(attributes, basestring)
891        self._recorded_attributes = dict((attribute, getattr(self, attribute))
892                                         for attribute in attributes)
893
894
895    def _check_for_updated_attributes(self):
896        """
897        See on_attribute_changed.
898        """
899        for attribute, original_value in self._recorded_attributes.iteritems():
900            new_value = getattr(self, attribute)
901            if original_value != new_value:
902                self.on_attribute_changed(attribute, original_value)
903        self._record_attributes(self._recorded_attributes.keys())
904
905
906    def on_attribute_changed(self, attribute, old_value):
907        """
908        Called whenever an attribute is updated.  To be overridden.
909
910        To use this method, you must:
911        * call _record_attributes() from __init__() (after making the super
912        call) with a list of attributes for which you want to be notified upon
913        change.
914        * call _check_for_updated_attributes() from save().
915        """
916        pass
917
918
919    def serialize(self, include_dependencies=True):
920        """Serializes the object with dependencies.
921
922        The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies
923        this function will serialize with the object.
924
925        @param include_dependencies: Whether or not to follow relations to
926                                     objects this object depends on.
927                                     This parameter is used when uploading
928                                     jobs from a shard to the master, as the
929                                     master already has all the dependent
930                                     objects.
931
932        @returns: Dictionary representation of the object.
933        """
934        serialized = {}
935        timer = autotest_stats.Timer('serialize_latency.%s' % (
936                type(self).__name__))
937        with timer.get_client('local'):
938            for field in self._meta.concrete_model._meta.local_fields:
939                if field.rel is None:
940                    serialized[field.name] = field._get_val_from_obj(self)
941                elif field.name in self.SERIALIZATION_LINKS_TO_KEEP:
942                    # attname will contain "_id" suffix for foreign keys,
943                    # e.g. HostAttribute.host will be serialized as 'host_id'.
944                    # Use it for easy deserialization.
945                    serialized[field.attname] = field._get_val_from_obj(self)
946
947        if include_dependencies:
948            with timer.get_client('related'):
949                for link in self.SERIALIZATION_LINKS_TO_FOLLOW:
950                    serialized[link] = self._serialize_relation(link)
951
952        return serialized
953
954
955    def _serialize_relation(self, link):
956        """Serializes dependent objects given the name of the relation.
957
958        @param link: Name of the relation to take objects from.
959
960        @returns For To-Many relationships a list of the serialized related
961            objects, for To-One relationships the serialized related object.
962        """
963        try:
964            attr = getattr(self, link)
965        except AttributeError:
966            # One-To-One relationships that point to None may raise this
967            return None
968
969        if attr is None:
970            return None
971        if hasattr(attr, 'all'):
972            return [obj.serialize() for obj in attr.all()]
973        return attr.serialize()
974
975
976    @classmethod
977    def _split_local_from_foreign_values(cls, data):
978        """This splits local from foreign values in a serialized object.
979
980        @param data: The serialized object.
981
982        @returns A tuple of two lists, both containing tuples in the form
983                 (link_name, link_value). The first list contains all links
984                 for local fields, the second one contains those for foreign
985                 fields/objects.
986        """
987        links_to_local_values, links_to_related_values = [], []
988        for link, value in data.iteritems():
989            if link in cls.SERIALIZATION_LINKS_TO_FOLLOW:
990                # It's a foreign key
991                links_to_related_values.append((link, value))
992            else:
993                # It's a local attribute or a foreign key
994                # we don't want to follow.
995                links_to_local_values.append((link, value))
996        return links_to_local_values, links_to_related_values
997
998
999    @classmethod
1000    def _filter_update_allowed_fields(cls, data):
1001        """Filters data and returns only files that updates are allowed on.
1002
1003        This is i.e. needed for syncing aborted bits from the master to shards.
1004
1005        Local links are only allowed to be updated, if they are in
1006        SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1007        Overwriting existing values is allowed in order to be able to sync i.e.
1008        the aborted bit from the master to a shard.
1009
1010        The whitelisting mechanism is in place to prevent overwriting local
1011        status: If all fields were overwritten, jobs would be completely be
1012        set back to their original (unstarted) state.
1013
1014        @param data: List with tuples of the form (link_name, link_value), as
1015                     returned by _split_local_from_foreign_values.
1016
1017        @returns List of the same format as data, but only containing data for
1018                 fields that updates are allowed on.
1019        """
1020        return [pair for pair in data
1021                if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE]
1022
1023
1024    @classmethod
1025    def delete_matching_record(cls, **filter_args):
1026        """Delete records matching the filter.
1027
1028        @param filter_args: Arguments for the django filter
1029                used to locate the record to delete.
1030        """
1031        try:
1032            existing_record = cls.objects.get(**filter_args)
1033        except cls.DoesNotExist:
1034            return
1035        existing_record.delete()
1036
1037
1038    def _deserialize_local(self, data):
1039        """Set local attributes from a list of tuples.
1040
1041        @param data: List of tuples like returned by
1042                     _split_local_from_foreign_values.
1043        """
1044        if not data:
1045            return
1046
1047        for link, value in data:
1048            setattr(self, link, value)
1049        # Overwridden save() methods are prone to errors, so don't execute them.
1050        # This is because:
1051        # - the overwritten methods depend on ACL groups that don't yet exist
1052        #   and don't handle errors
1053        # - the overwritten methods think this object already exists in the db
1054        #   because the id is already set
1055        super(type(self), self).save()
1056
1057
1058    def _deserialize_relations(self, data):
1059        """Set foreign attributes from a list of tuples.
1060
1061        This deserialized the related objects using their own deserialize()
1062        function and then sets the relation.
1063
1064        @param data: List of tuples like returned by
1065                     _split_local_from_foreign_values.
1066        """
1067        for link, value in data:
1068            self._deserialize_relation(link, value)
1069        # See comment in _deserialize_local
1070        super(type(self), self).save()
1071
1072
1073    @classmethod
1074    def get_record(cls, data):
1075        """Retrieve a record with the data in the given input arg.
1076
1077        @param data: A dictionary containing the information to use in a query
1078                for data. If child models have different constraints of
1079                uniqueness they should override this model.
1080
1081        @return: An object with matching data.
1082
1083        @raises DoesNotExist: If a record with the given data doesn't exist.
1084        """
1085        return cls.objects.get(id=data['id'])
1086
1087
1088    @classmethod
1089    def deserialize(cls, data):
1090        """Recursively deserializes and saves an object with it's dependencies.
1091
1092        This takes the result of the serialize method and creates objects
1093        in the database that are just like the original.
1094
1095        If an object of the same type with the same id already exists, it's
1096        local values will be left untouched, unless they are explicitly
1097        whitelisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1098
1099        Deserialize will always recursively propagate to all related objects
1100        present in data though.
1101        I.e. this is necessary to add users to an already existing acl-group.
1102
1103        @param data: Representation of an object and its dependencies, as
1104                     returned by serialize.
1105
1106        @returns: The object represented by data if it didn't exist before,
1107                  otherwise the object that existed before and has the same type
1108                  and id as the one described by data.
1109        """
1110        if data is None:
1111            return None
1112
1113        local, related = cls._split_local_from_foreign_values(data)
1114        try:
1115            instance = cls.get_record(data)
1116            local = cls._filter_update_allowed_fields(local)
1117        except cls.DoesNotExist:
1118            instance = cls()
1119
1120        timer = autotest_stats.Timer('deserialize_latency.%s' % (
1121                type(instance).__name__))
1122        with timer.get_client('local'):
1123            instance._deserialize_local(local)
1124        with timer.get_client('related'):
1125            instance._deserialize_relations(related)
1126
1127        return instance
1128
1129
1130    def sanity_check_update_from_shard(self, shard, updated_serialized,
1131                                       *args, **kwargs):
1132        """Check if an update sent from a shard is legitimate.
1133
1134        @raises error.UnallowedRecordsSentToMaster if an update is not
1135                legitimate.
1136        """
1137        raise NotImplementedError(
1138            'sanity_check_update_from_shard must be implemented by subclass %s '
1139            'for type %s' % type(self))
1140
1141
1142    @transaction.commit_on_success
1143    def update_from_serialized(self, serialized):
1144        """Updates local fields of an existing object from a serialized form.
1145
1146        This is different than the normal deserialize() in the way that it
1147        does update local values, which deserialize doesn't, but doesn't
1148        recursively propagate to related objects, which deserialize() does.
1149
1150        The use case of this function is to update job records on the master
1151        after the jobs have been executed on a slave, as the master is not
1152        interested in updates for users, labels, specialtasks, etc.
1153
1154        @param serialized: Representation of an object and its dependencies, as
1155                           returned by serialize.
1156
1157        @raises ValueError: if serialized contains related objects, i.e. not
1158                            only local fields.
1159        """
1160        local, related = (
1161            self._split_local_from_foreign_values(serialized))
1162        if related:
1163            raise ValueError('Serialized must not contain foreign '
1164                             'objects: %s' % related)
1165
1166        self._deserialize_local(local)
1167
1168
1169    def custom_deserialize_relation(self, link, data):
1170        """Allows overriding the deserialization behaviour by subclasses."""
1171        raise NotImplementedError(
1172            'custom_deserialize_relation must be implemented by subclass %s '
1173            'for relation %s' % (type(self), link))
1174
1175
1176    def _deserialize_relation(self, link, data):
1177        """Deserializes related objects and sets references on this object.
1178
1179        Relations that point to a list of objects are handled automatically.
1180        For many-to-one or one-to-one relations custom_deserialize_relation
1181        must be overridden by the subclass.
1182
1183        Related objects are deserialized using their deserialize() method.
1184        Thereby they and their dependencies are created if they don't exist
1185        and saved to the database.
1186
1187        @param link: Name of the relation.
1188        @param data: Serialized representation of the related object(s).
1189                     This means a list of dictionaries for to-many relations,
1190                     just a dictionary for to-one relations.
1191        """
1192        field = getattr(self, link)
1193
1194        if field and hasattr(field, 'all'):
1195            self._deserialize_2m_relation(link, data, field.model)
1196        else:
1197            self.custom_deserialize_relation(link, data)
1198
1199
1200    def _deserialize_2m_relation(self, link, data, related_class):
1201        """Deserialize related objects for one to-many relationship.
1202
1203        @param link: Name of the relation.
1204        @param data: Serialized representation of the related objects.
1205                     This is a list with of dictionaries.
1206        @param related_class: A class representing a django model, with which
1207                              this class has a one-to-many relationship.
1208        """
1209        relation_set = getattr(self, link)
1210        if related_class == self.get_attribute_model():
1211            # When deserializing a model together with
1212            # its attributes, clear all the exising attributes to ensure
1213            # db consistency. Note 'update' won't be sufficient, as we also
1214            # want to remove any attributes that no longer exist in |data|.
1215            #
1216            # core_filters is a dictionary of filters, defines how
1217            # RelatedMangager would query for the 1-to-many relationship. E.g.
1218            # Host.objects.get(
1219            #     id=20).hostattribute_set.core_filters = {host_id:20}
1220            # We use it to delete objects related to the current object.
1221            related_class.objects.filter(**relation_set.core_filters).delete()
1222        for serialized in data:
1223            relation_set.add(related_class.deserialize(serialized))
1224
1225
1226    @classmethod
1227    def get_attribute_model(cls):
1228        """Return the attribute model.
1229
1230        Subclass with attribute-like model should override this to
1231        return the attribute model class. This method will be
1232        called by _deserialize_2m_relation to determine whether
1233        to clear the one-to-many relations first on deserialization of object.
1234        """
1235        return None
1236
1237
class ModelWithInvalid(ModelExtensions):
    """
    Model base that marks objects as invalid instead of deleting them.
    save() and delete() are overridden accordingly.  Subclasses must have a
    boolean "invalid" field.
    """

    def save(self, *args, **kwargs):
        """
        Save the object.  If it has no id yet and an invalidated object with
        the same name_field value exists, that object is resurrected via
        resurrect_object() before saving.
        """
        if self.id is None:
            # First save: see if this object was previously added and
            # invalidated.
            name_value = getattr(self, self.name_field)
            lookup = {self.name_field: name_value, 'invalid': True}
            try:
                previous = self.__class__.objects.get(**lookup)
                self.resurrect_object(previous)
            except self.DoesNotExist:
                # no existing object
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is actually
        "undeleting" a previously deleted object.  Can be overridden by
        subclasses to copy data as desired from the deleted entry (but this
        superclass implementation must normally be called).

        This implementation takes over the id of old_object.
        """
        self.id = old_object.id


    def clean_object(self):
        """
        Hook called by delete() after the object has been marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.  Does nothing by
        default.
        """
        return None


    def delete(self):
        """
        Mark the object invalid, save it and call clean_object(), instead of
        removing it.  The object must not already be invalid.
        """
        # NOTE(review): self-assignment kept from the original; its purpose
        # is not evident from this code.
        self.invalid = self.invalid
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the manager that only sees valid objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            # NOTE(review): relies on a parent class providing _prepare();
            # the only base visible here is object -- TODO confirm.
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
1303
1304
class ModelWithAttributes(object):
    """
    Mixin for models that keep their attributes in a separate attribute
    model.  The attribute model is assumed to have its value field named
    "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.
        """
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """Create the attribute if needed, then set its value and save it."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        row, _created = model.objects.get_or_create(**lookup)
        row.value = value
        row.save()


    def delete_attribute(self, attribute):
        """Delete the attribute; do nothing if it does not exist."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            model.objects.get(**lookup).delete()
        except model.DoesNotExist:
            pass


    def set_or_delete_attribute(self, attribute, value):
        """Set the attribute to value, or delete it if value is None."""
        if value is not None:
            self.set_attribute(attribute, value)
        else:
            self.delete_attribute(attribute)
1343
1344
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class"""

    def create(self, **kwargs):
        """Always raises: callers must use get_or_create() instead."""
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """
        Compute the hash for kwargs with the model's _compute_hash(), add it
        as 'the_hash' and delegate to the regular get_or_create().
        """
        computed_hash = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = computed_hash
        return super(ModelWithHashManager, self).get_or_create(**kwargs)
1356
1357
class ModelWithHash(dbmodels.Model):
    """Abstract base for models that carry a unique hash column"""

    # Unique hash column; ModelWithHashManager.get_or_create() fills it in
    # from _compute_hash().
    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Compute the hash for the given field values; must be overridden."""
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases

        We want these models to be immutable, so the generic save() operation
        will not work. These models should be instantiated through their
        model.objects.get_or_create() method instead.

        The exception is that save(force_insert=True) will be allowed, since
        that creates a new row. However, the preferred way to make instances of
        these models is through the get_or_create() method.

        @raises Exception: if force_insert is not set.
        """
        if force_insert:
            # Allow a forced insert to happen; if it's a duplicate, the
            # unique constraint will catch it later anyways
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')
1390