validators.py revision d1603ecc4264a1c789e0de7e2c584ab83235a068
1# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Validators to verify if events conform to specified criteria."""
6
7
8'''
9How to add a new validator/gesture:
10(1) Implement a new validator class inheriting BaseValidator,
11(2) add proper method in mtb.Mtb class,
12(3) add the new validator in test_conf, and
13        'from validators import the_new_validator'
14    in alphabetical order, and
15(4) add the validator in relevant gestures; add a new gesture if necessary.
16
17The validator template is as follows:
18
19class XxxValidator(BaseValidator):
20    """Validator to check ...
21
22    Example:
23        To check ...
24          XxxValidator('<= 0.05, ~ +0.05', fingers=2)
25    """
26
    def __init__(self, criteria_str, mf=None, device=None, fingers=1):
        name = self.__class__.__name__
        super(XxxValidator, self).__init__(criteria_str, mf, device, name)
        self.fingers = fingers
31
32    def check(self, packets, variation=None):
33        """Check ..."""
34        self.init_check(packets)
35        xxx = self.packets.xxx()
        self.log_details(...)
        self.vlog.score = self.fc.mf.grade(...)
        return self.vlog
38
39
40Note that it is also possible to instantiate a validator as
41          XxxValidator('<= 0.05, ~ +0.05', slot=0)
42
43    Difference between fingers and slot:
44      . When specifying 'fingers', e.g., fingers=2, the purpose is to pass
45        the information about how many fingers there are in the gesture. In
        this case, the events in a specific slot are usually not important.
47        An example is to check how many fingers there are when making a click:
48            PhysicalClickValidator('== 0', fingers=2)
      . When specifying 'slot', e.g., slot=0, the purpose is to pass the slot
        number to the validator to examine detailed events in that slot.
        An example of such usage:
            NoGapValidator('<= 5, ~ +5', slot=0)
53'''
54
55
56import copy
57import numpy as np
58import os
59import re
60
61import firmware_log
62import fuzzy
63import mtb
64
65from collections import namedtuple
66from inspect import isfunction
67
68from common_util import print_and_exit
69from firmware_constants import AXIS, GV, MTB, UNIT, VAL
70
71
# Fraction of a line's points, taken at each end, that the segment-based
# (begin / middle / end) checks treat as the end segments.
END_PERCENTAGE = 0.10

# Suffix shared by every validator class name, e.g. 'LinearityValidator'.
VALIDATOR = 'Validator'
77
78
def validate(packets, gesture, variation):
    """Validate a single gesture.

    Runs every validator attached to the gesture against the packets and
    collects their results.

    @param packets: the recorded packets of the gesture, or None
    @param gesture: an object whose `validators` attribute is an iterable of
        validators, each providing check(packets, variation)
    @param variation: the gesture variation handed to every validator

    @return: a tuple (score_list, msg_list, vlogs) where
        score_list: the scores of the validators whose score is not None,
        msg_list: the formatted report lines of those validators,
        vlogs: deep copies of every validator log that is not None.
        (None, None, None) is returned when packets is None.
    """
    if packets is None:
        # Keep the same arity as the normal return so callers can always
        # unpack three values.
        return (None, None, None)

    msg_list = []
    score_list = []
    vlogs = []
    for validator in gesture.validators:
        vlog = validator.check(packets, variation)
        if vlog is None:
            continue
        # Validators hand back their own log object, which is reset on the
        # next check, so keep a snapshot of it.
        vlogs.append(copy.deepcopy(vlog))

        score = vlog.score
        if score is None:
            continue

        score_list.append(score)
        # Save the validator messages.
        msg_list.append(os.linesep)
        msg_list.append('%s' % vlog.name)
        msg_list += vlog.details
        msg_list.append('    criteria_str: %s' % vlog.criteria)
        msg_list.append('score: %f' % score)

    return (score_list, msg_list, vlogs)
107
108
def get_short_name(validator_name):
    """Return validator_name cut off at its first 'Validator' suffix.

    E.g., the short name of LinearityValidator is Linearity. A name that
    does not contain the suffix is returned unchanged.
    """
    short_name, _, _ = validator_name.partition(VALIDATOR)
    return short_name
115
116
def get_validator_name(short_name):
    """Append the validator suffix to short_name.

    E.g., the validator name of Linearity is LinearityValidator.
    """
    return ''.join((short_name, VALIDATOR))
123
124
def get_base_name_and_segment(validator_name):
    """Get the base name and segment of a validator.

    Examples:
        Ex 1: Linearity(BothEnds)Validator
            return ('Linearity', 'BothEnds')
        Ex 2: NoGapValidator
            return ('NoGap', None)

    A name that contains '(' but does not follow the
    '<base>(<segment>)Validator' form is treated like a name without a
    segment instead of raising AttributeError.
    """
    if '(' in validator_name:
        # Raw string: '\(' is not a valid escape in a normal string literal.
        pattern = r'(.*)\((.*)\)%s' % re.escape(VALIDATOR)
        result = re.search(pattern, validator_name)
        if result:
            return (result.group(1), result.group(2))
    return (get_short_name(validator_name), None)
139
140
def get_derived_name(validator_name, segment):
    """Build the validator name that embeds a segment.

    Example:
      validator_name: LinearityValidator
      segment: Middle
      derived_name: Linearity(Middle)Validator
    """
    return '{0}({1}){2}'.format(get_short_name(validator_name), segment,
                                VALIDATOR)
152
153
def init_base_validator(device):
    """Register device as the class-wide default of every Validator.

    Validators constructed without a (truthy) device argument fall back to
    this one.
    """
    setattr(BaseValidator, '_device', device)
157
158
class BaseValidator(object):
    """Base class of validators.

    A subclass implements check(packets, variation): it calls init_check(),
    inspects self.packets, records details/metrics in self.vlog, sets
    self.vlog.score and returns self.vlog.
    """
    aggregator = 'fuzzy.average'
    # Class-wide default device; assigned by init_base_validator().
    _device = None

    def __init__(self, criteria, mf=None, device=None, name=None):
        """Set up the fuzzy criteria and the validator log.

        @param criteria: a criteria string such as '<= 0.05, ~ +0.05', or a
            function returning such a string
        @param mf: passed through to fuzzy.FuzzyCriteria as its mf argument
        @param device: the touch device; BaseValidator._device is used when
            this is not given (or is falsy)
        @param name: the validator name recorded in the log
        """
        self.criteria_str = criteria() if isfunction(criteria) else criteria
        self.fc = fuzzy.FuzzyCriteria(self.criteria_str, mf=mf)
        self.device = device if device else BaseValidator._device
        self.packets = None
        self.vlog = firmware_log.ValidatorLog()
        self.vlog.name = name
        self.vlog.criteria = self.criteria_str
        self.mnprops = firmware_log.MetricNameProps()

    def init_check(self, packets=None):
        """Initialization before check() is called.

        Wraps the packets in an mtb.Mtb object for self.device and resets
        the validator log.
        """
        self.packets = mtb.Mtb(device=self.device, packets=packets)
        self.vlog.reset()

    def _is_direction_in_variation(self, variation, directions):
        """Is any element of directions list found in variation?

        A variation of None (the default of check()) contains no direction.
        """
        if variation is None:
            return False
        return any(direction in variation for direction in directions)

    def is_horizontal(self, variation):
        """Is the direction horizontal?"""
        return self._is_direction_in_variation(variation,
                                               GV.HORIZONTAL_DIRECTIONS)

    def is_vertical(self, variation):
        """Is the direction vertical?"""
        return self._is_direction_in_variation(variation,
                                               GV.VERTICAL_DIRECTIONS)

    def is_diagonal(self, variation):
        """Is the direction diagonal?"""
        return self._is_direction_in_variation(variation,
                                               GV.DIAGONAL_DIRECTIONS)

    def get_direction(self, variation):
        """Get the direction.

        @return: GV.HORIZONTAL, GV.VERTICAL or GV.DIAGONAL, or None if the
            variation contains none of the known directions.
        """
        # TODO(josephsih): raise an exception if a proper direction is not found
        if self.is_horizontal(variation):
            return GV.HORIZONTAL
        elif self.is_vertical(variation):
            return GV.VERTICAL
        elif self.is_diagonal(variation):
            return GV.DIAGONAL
        return None

    def get_direction_in_variation(self, variation):
        """Get the direction string from the variation list.

        @param variation: a tuple of variation strings, or a single one
        @return: the first element found in GV.GESTURE_DIRECTIONS, or None
        """
        if isinstance(variation, tuple):
            for var in variation:
                if var in GV.GESTURE_DIRECTIONS:
                    return var
        elif variation in GV.GESTURE_DIRECTIONS:
            return variation
        return None

    def log_details(self, msg):
        """Collect the detailed messages to be printed within this module."""
        prefix_space = ' ' * 4
        formatted_msg = '%s%s' % (prefix_space, msg)
        self.vlog.insert_details(formatted_msg)

    def get_threshold(self, criteria_str, op):
        """Search the criteria_str using regular expressions and get
        the threshold value.

        @param criteria_str: the criteria string to search
        @param op: the comparison operator preceding the threshold, e.g. '>'
        @return: the first run of digits after op as an int, or None if
            there is none. Only the integer part of a decimal threshold is
            captured (e.g. '>= 0.5' yields 0).
        """
        # In the search pattern, '.*?' is non-greedy, which will match as
        # few characters as possible.
        #   E.g., op = '>'
        #         criteria_str = '>= 200, ~ -100'
        #         pattern below would be '>.*?\s*(\d+)'
        #         result.group(1) below would be '200'
        # op is escaped so that regex metacharacters in it match literally.
        pattern = r'{}.*?\s*(\d+)'.format(re.escape(op))
        result = re.search(pattern, criteria_str)
        return int(result.group(1)) if result else None

    def _get_axes_by_finger(self, finger):
        """Get list_x, list_y, and list_t for the specified finger.

        @param finger: the finger contact
        @return: a tuple (list_x, list_y, list_t) of the finger's ordered
            x coordinates, y coordinates and syn_time values
        """
        points = self.packets.get_ordered_finger_path(finger, 'point')
        list_x = [p.x for p in points]
        list_y = [p.y for p in points]
        list_t = self.packets.get_ordered_finger_path(finger, 'syn_time')
        return (list_x, list_y, list_t)
253
254
class LinearityValidator1(BaseValidator):
    """Validator to verify linearity.

    A simple linear regression line is fitted to the middle segment of the
    line, and the fitting error of the chosen segments against that line is
    graded in mm.

    Example:
        To check the linearity of the line drawn in finger 1:
          LinearityValidator1('<= 0.03, ~ +0.07', finger=1)
    """
    # Define the partial group size for calculating Mean Squared Error
    MSE_PARTIAL_GROUP_SIZE = 1

    def __init__(self, criteria_str, mf=None, device=None, finger=0,
                 segments=VAL.WHOLE):
        self._segments = segments
        self.finger = finger
        name = get_derived_name(self.__class__.__name__, segments)
        super(LinearityValidator1, self).__init__(criteria_str, mf, device,
                                                  name)

    def _simple_linear_regression(self, ax, ay):
        """Calculate the simple linear regression and return the fitting
        error of the specified segments.

        It calculates the simple linear regression line for the points
        in the middle segment of the line. This excludes the points at
        both ends of the line which sometimes have wobbles. Then it
        calculates the fitting errors of the points at the specified segments
        against the computed simple linear regression line.

        @return: the fitting error in pixels, or 0 if no line could be fitted
            to the middle segment. For VAL.BOTH_ENDS the larger of the
            begin/end errors is returned.
        """
        # Compute the simple linear regression line for the middle segment
        # whose purpose is to avoid wobbles on both ends of the line.
        mid_segment = self.packets.get_segments_x_and_y(ax, ay, VAL.MIDDLE,
                                                        END_PERCENTAGE)
        if not self._calc_simple_linear_regression_line(*mid_segment):
            return 0

        # Compute the fitting errors of the specified segments.
        if self._segments == VAL.BOTH_ENDS:
            bgn_segment = self.packets.get_segments_x_and_y(ax, ay, VAL.BEGIN,
                                                            END_PERCENTAGE)
            end_segment = self.packets.get_segments_x_and_y(ax, ay, VAL.END,
                                                            END_PERCENTAGE)
            bgn_error = self._calc_simple_linear_regression_error(*bgn_segment)
            end_error = self._calc_simple_linear_regression_error(*end_segment)
            return max(bgn_error, end_error)
        else:
            target_segment = self.packets.get_segments_x_and_y(ax, ay,
                    self._segments, END_PERCENTAGE)
            return self._calc_simple_linear_regression_error(*target_segment)

    def _calc_simple_linear_regression_line(self, ax, ay):
        """Calculate the simple linear regression line.

           ax: array x
           ay: array y
           This method tries to find alpha and beta in the formula
                ay = alpha + beta . ax
           such that it has the least sum of squared residuals.

           @return: True with self._alpha and self._beta set; False if no
               unique line can be fitted (fewer than 3 points, or all ax
               values identical).

           Reference:
           - Simple linear regression:
             http://en.wikipedia.org/wiki/Simple_linear_regression
           - Average absolute deviation (or mean absolute deviation) :
             http://en.wikipedia.org/wiki/Average_absolute_deviation
        """
        # Convert the int list to the float array
        self._ax = 1.0 * np.array(ax)
        self._ay = 1.0 * np.array(ay)

        # Require at least 3 data points before treating them as a line.
        asize = self._ax.size
        if asize <= 2:
            return False

        Sx = self._ax.sum()
        Sy = self._ay.sum()
        Sxx = np.square(self._ax).sum()
        Sxy = np.dot(self._ax, self._ay)

        # The denominator is zero when every ax value is the same; beta
        # would then be inf/nan, so report that no line could be fitted.
        denominator = asize * Sxx - Sx * Sx
        if denominator == 0:
            return False

        # Compute beta and alpha of the linear regression
        self._beta = 1.0 * (asize * Sxy - Sx * Sy) / denominator
        self._alpha = self._ay.mean() - self._beta * self._ax.mean()
        return True

    def _calc_simple_linear_regression_error(self, ax, ay):
        """Calculate the fitting error based on the simple linear regression
        line characterized by the equation parameters alpha and beta.

        @return: the square root of the mean of the MSE_PARTIAL_GROUP_SIZE
            largest squared residuals; 0.0 for an empty segment.
        """
        # Convert the int list to the float array
        ax = 1.0 * np.array(ax)
        ay = 1.0 * np.array(ay)

        asize = ax.size
        # Averaging an empty array yields nan, which would then leak through
        # max() in the BOTH_ENDS case; an empty segment has no error.
        if asize == 0:
            return 0.0
        partial = min(asize, max(1, self.MSE_PARTIAL_GROUP_SIZE))

        # spmse: squared root of partial mean squared error
        spmse = np.square(ay - self._alpha - self._beta * ax)
        spmse.sort()
        spmse = spmse[asize - partial:asize]
        return np.sqrt(np.average(spmse))

    def check(self, packets, variation=None):
        """Check if the packets conforms to specified criteria."""
        self.init_check(packets)
        resolution_x, resolution_y = self.device.get_resolutions()
        (list_x, list_y) = self.packets.get_x_y(self.finger)
        # Compute average distance (fitting error) in pixels, and
        # average deviation on touch device in mm.
        if self.is_vertical(variation):
            # Regress x on y, so the residuals are x distances.
            ave_distance = self._simple_linear_regression(list_y, list_x)
            deviation = ave_distance / resolution_x
        else:
            # Regress y on x, so the residuals are y distances.
            ave_distance = self._simple_linear_regression(list_x, list_y)
            deviation = ave_distance / resolution_y

        self.log_details('ave fitting error: %.2f px' % ave_distance)
        msg_device = 'deviation finger%d: %.2f mm'
        self.log_details(msg_device % (self.finger, deviation))
        self.vlog.score = self.fc.mf.grade(deviation)
        return self.vlog
382
383
class LinearityValidator(BaseValidator):
    """A validator to verify linearity based on x-t and y-t

    Example:
        To check the linearity of the line drawn in finger 1:
          LinearityValidator('<= 0.03, ~ +0.07', finger=1)
        Note: the finger number begins from 0
    """
    # Define the partial group size for calculating Mean Squared Error
    MSE_PARTIAL_GROUP_SIZE = 1

    def __init__(self, criteria_str, mf=None, device=None, finger=0,
                 segments=VAL.WHOLE):
        self._segments = segments
        self.finger = finger
        name = get_derived_name(self.__class__.__name__, segments)
        super(LinearityValidator, self).__init__(criteria_str, mf, device,
                                                  name)

    def _calc_residuals(self, line, list_t, list_y):
        """Calculate the residuals of the points in list_t, list_y against
        the line.

        @param line: the regression line of list_t and list_y
        @param list_t: a list of time instants
        @param list_y: a list of x/y coordinates

        This method returns the list of residuals, where
            residual[i] = line[t_i] - y_i
        where t_i is an element in list_t and
              y_i is a corresponding element in list_y.

        We calculate the vertical distance (y distance) here because the
        horizontal axis, list_t, always represent the time instants, and the
        vertical axis, list_y, could be either the coordinates in x or y axis.
        """
        return [float(line(t) - y) for t, y in zip(list_t, list_y)]

    def _do_simple_linear_regression(self, list_t, list_y):
        """Fit a simple linear regression line and return the residuals.

        @param list_t: the list of time instants
        @param list_y: the list of x or y coordinates of touch contacts
        @return: the list of residuals (fitting errors in pixels) of the
            points in the specified segments; an empty list when there are
            fewer than 2 points.

        The line is fitted to the middle segment of the points. The residuals
        of the points in the specified segments are then computed against
        that line.

        Reference:
        - Simple linear regression:
          http://en.wikipedia.org/wiki/Simple_linear_regression
        - numpy.polyfit(): used to calculate the simple linear regression line.
          http://docs.scipy.org/doc/numpy/reference/generated/numpy.polyfit.html
        """
        # At least 2 points to determine a line.
        if len(list_t) < 2 or len(list_y) < 2:
            return []

        mid_segment_t, mid_segment_y = self.packets.get_segments(
                list_t, list_y, VAL.MIDDLE, END_PERCENTAGE)
        # A degree-1 np.polyfit() needs at least 2 points. Fall back to all
        # the points when the middle segment is shorter than that.
        if len(mid_segment_t) < 2:
            mid_segment_t, mid_segment_y = list_t, list_y

        # Calculate the simple linear regression line.
        degree = 1
        regress_line = np.poly1d(np.polyfit(mid_segment_t, mid_segment_y,
                                            degree))

        # Compute the fitting errors of the specified segments.
        if self._segments == VAL.BOTH_ENDS:
            begin_segments = self.packets.get_segments(
                    list_t, list_y, VAL.BEGIN, END_PERCENTAGE)
            end_segments = self.packets.get_segments(
                    list_t, list_y, VAL.END, END_PERCENTAGE)
            begin_error = self._calc_residuals(regress_line, *begin_segments)
            end_error = self._calc_residuals(regress_line, *end_segments)
            return begin_error + end_error
        else:
            target_segments = self.packets.get_segments(
                    list_t, list_y, self._segments, END_PERCENTAGE)
            return self._calc_residuals(regress_line, *target_segments)

    def _calc_errors_single_axis(self, list_t, list_y):
        """Calculate various errors for axis-time.

        @param list_t: the list of time instants
        @param list_y: the list of x or y coordinates of touch contacts
        @return: (max absolute error, root mean square error) in pixels;
            (0, 0) when no residuals are available.
        """
        # It is fine if axis-time is a horizontal line.
        errors_px = self._do_simple_linear_regression(list_t, list_y)
        if not errors_px:
            return (0, 0)

        # Calculate the max errors
        max_err_px = max(map(abs, errors_px))

        # Calculate the root mean square errors
        e2 = [e * e for e in errors_px]
        rms_err_px = (float(sum(e2)) / len(e2)) ** 0.5

        return (max_err_px, rms_err_px)

    def _calc_errors_all_axes(self, list_t, list_x, list_y):
        """Calculate various errors for all axes."""
        # Calculate max error and average squared error
        (max_err_x_px, rms_err_x_px) = self._calc_errors_single_axis(
                list_t, list_x)
        (max_err_y_px, rms_err_y_px) = self._calc_errors_single_axis(
                list_t, list_y)

        # Convert the unit from pixels to mms
        self.max_err_x_mm, self.max_err_y_mm = self.device.pixel_to_mm(
                (max_err_x_px, max_err_y_px))
        self.rms_err_x_mm, self.rms_err_y_mm = self.device.pixel_to_mm(
                (rms_err_x_px, rms_err_y_px))

    def _log_details_and_metrics(self, variation):
        """Log the details and calculate the metrics.

        @param variation: the gesture variation
        """
        list_x, list_y, list_t = self._get_axes_by_finger(self.finger)
        X, Y = AXIS.LIST
        if self.is_horizontal(variation):
            # For horizontal lines, only consider x axis
            list_coords = {X: list_x}
        elif self.is_vertical(variation):
            # For vertical lines, only consider y axis
            list_coords = {Y: list_y}
        else:
            # Diagonal lines, and variations without a recognized direction,
            # consider both x and y axes. Always assigning here also keeps
            # the axes of a previous check() from leaking into this one.
            list_coords = {X: list_x, Y: list_y}
        self.list_coords = list_coords

        self.max_err_mm = {}
        self.rms_err_mm = {}
        self.vlog.metrics = []
        mnprops = self.mnprops
        pixel_to_mm = self.device.pixel_to_mm_single_axis_by_name
        for axis, list_c in self.list_coords.items():
            max_err_px, rms_err_px = self._calc_errors_single_axis(
                    list_t, list_c)
            max_err_mm = pixel_to_mm(max_err_px, axis)
            rms_err_mm = pixel_to_mm(rms_err_px, axis)
            self.log_details('max_err[%s]: %.2f mm' % (axis, max_err_mm))
            self.log_details('rms_err[%s]: %.2f mm' % (axis, rms_err_mm))
            self.vlog.metrics.extend([
                firmware_log.Metric(mnprops.MAX_ERR.format(axis), max_err_mm),
                firmware_log.Metric(mnprops.RMS_ERR.format(axis), rms_err_mm),
            ])
            self.max_err_mm[axis] = max_err_mm
            self.rms_err_mm[axis] = rms_err_mm

    def check(self, packets, variation=None):
        """Check if the packets conforms to specified criteria."""
        self.init_check(packets)
        self._log_details_and_metrics(variation)
        # Calculate the score based on the max error
        max_err = max(self.max_err_mm.values())
        self.vlog.score = self.fc.mf.grade(max_err)
        return self.vlog
542
543
class RangeValidator(BaseValidator):
    """Validator checking that the observed (x, y) positions reach the
    reported min/max values of the device.

    Example:
        To check the range of observed edge-to-edge positions:
          RangeValidator('<= 0.05, ~ +0.05')
    """

    def __init__(self, criteria_str, mf=None, device=None):
        self.name = self.__class__.__name__
        super(RangeValidator, self).__init__(criteria_str, mf, device,
                                             self.name)

    def check(self, packets, variation=None):
        """Check the left/right or top/bottom range based on the direction.

        The distance between the observed edge and the device's spec edge
        in the gesture direction is graded as a ratio of the axis span.
        """
        self.init_check(packets)
        edges = [GV.CL, GV.CR, GV.CT, GV.CB]
        Range = namedtuple('Range', edges)
        observed = Range(*self.packets.get_range())
        dev = self.device
        spec = Range(dev.axis_x.min, dev.axis_x.max,
                     dev.axis_y.min, dev.axis_y.max)

        direction = self.get_direction_in_variation(variation)
        if direction in edges:
            actual_edge = getattr(observed, direction)
            spec_edge = getattr(spec, direction)
            short_px = abs(actual_edge - spec_edge)
        else:
            err_msg = 'Error: the gesture variation %s is not allowed in %s.'
            print_and_exit(err_msg % (variation, self.name))

        if self.is_horizontal(variation):
            axis_spec = dev.axis_x
        else:
            axis_spec = dev.axis_y
        deviation_ratio = float(short_px) / (axis_spec.max - axis_spec.min)

        # The edge name is the last word of the direction,
        # e.g. center_to_left -> left.
        edge_name = direction.rsplit('_', 1)[-1]
        metric_name = self.mnprops.RANGE.format(edge_name)
        short_mm = self.device.pixel_to_mm_single_axis(short_px, axis_spec)
        self.vlog.metrics = [firmware_log.Metric(metric_name, short_mm)]

        self.log_details('actual: px %s' % (actual_edge,))
        self.log_details('spec: px %s' % (spec_edge,))
        self.log_details('short of range: %d px == %f mm' %
                         (short_px, short_mm))
        self.vlog.score = self.fc.mf.grade(deviation_ratio)
        return self.vlog
596
597
class CountTrackingIDValidator(BaseValidator):
    """Validator to check the count of tracking IDs.

    Example:
        To verify if there is exactly one finger observed:
          CountTrackingIDValidator('== 1')
    """

    def __init__(self, criteria_str, mf=None, device=None):
        name = self.__class__.__name__
        super(CountTrackingIDValidator, self).__init__(criteria_str, mf,
                                                       device, name)

    def check(self, packets, variation=None):
        """Check the number of tracking IDs observed."""
        self.init_check(packets)

        # Get the actual count of tracking id and log the details.
        actual_count_tid = self.packets.get_number_contacts()
        self.log_details('count of trackid IDs: %d' % actual_count_tid)

        # Only keep metrics with the criteria '== N'.
        # Ignore those with '>= N' which are used to assert that users have
        # performed correct gestures. As an example, we require that users
        # tap more than a certain number of times in the drumroll test.
        if '==' in self.criteria_str:
            # get_threshold() takes the first integer after '==', so criteria
            # carrying a tolerance such as '== 1, ~ +1' are handled too;
            # int() on the remainder of the string would raise ValueError.
            expected_count_tid = self.get_threshold(self.criteria_str, '==')
            if expected_count_tid is not None:
                # E.g., expected_count_tid = 2
                #       actual_count_tid could be either smaller (e.g., 1) or
                #       larger (e.g., 3).
                metric_value = (actual_count_tid, expected_count_tid)
                metric_name = self.mnprops.TID
                self.vlog.metrics = [
                    firmware_log.Metric(metric_name, metric_value)]

        self.vlog.score = self.fc.mf.grade(actual_count_tid)
        return self.vlog
634
635
class StationaryFingerValidator(BaseValidator):
    """Validator to check that a stationary finger does not move too far.

    The maximum distance (in mm) moved by the finger in the given slot is
    graded against the criteria.

    Example:
        To verify if the stationary finger specified by the slot does not
        move larger than a specified radius:
          StationaryFingerValidator('<= 15 ~ +10')
    """

    def __init__(self, criteria, mf=None, device=None, slot=0):
        name = self.__class__.__name__
        super(StationaryFingerValidator, self).__init__(criteria, mf,
                                                        device, name)
        self.slot = slot

    def check(self, packets, variation=None):
        """Check the moving distance of the specified finger."""
        self.init_check(packets)
        max_distance = self.packets.get_max_distance(self.slot, UNIT.MM)
        # Use a fractional format: '%d' truncated sub-millimetre distances.
        msg = 'Max distance slot%d: %.2f mm'
        self.log_details(msg % (self.slot, max_distance))
        self.vlog.metrics = [
            firmware_log.Metric(self.mnprops.MAX_DISTANCE, max_distance)
        ]
        self.vlog.score = self.fc.mf.grade(max_distance)
        return self.vlog
662
663
class NoGapValidator(BaseValidator):
    """Validator to make sure that there are no significant gaps in a line.

    The largest gap ratio reported by the packets for the slot is graded
    against the criteria.

    Example:
        To verify that the line drawn in slot 1 has no significant gaps:
          NoGapValidator('<= 5, ~ +5', slot=1)
    """

    def __init__(self, criteria_str, mf=None, device=None, slot=0):
        name = self.__class__.__name__
        super(NoGapValidator, self).__init__(criteria_str, mf, device, name)
        self.slot = slot

    def check(self, packets, variation=None):
        """There should be no significant gaps in a line."""
        self.init_check(packets)
        # Get the largest gap ratio
        gap_ratio = self.packets.get_largest_gap_ratio(self.slot)
        msg = 'Largest gap ratio slot%d: %f'
        self.log_details(msg % (self.slot, gap_ratio))
        self.vlog.score = self.fc.mf.grade(gap_ratio)
        return self.vlog
686
687
class NoReversedMotionValidator(BaseValidator):
    """Validator measuring the motions that go against the gesture direction
    in the given slots.

    Example:
        To measure the reversed motions in slot 0:
          NoReversedMotionValidator('== 0, ~ +20', slots=0)
    """
    def __init__(self, criteria_str, mf=None, device=None, slots=(0,),
                 segments=VAL.MIDDLE):
        # A single int slot is accepted for convenience.
        if isinstance(slots, int):
            slots = (slots,)
        self.slots = slots
        self._segments = segments
        name = get_derived_name(self.__class__.__name__, segments)
        super(NoReversedMotionValidator, self).__init__(criteria_str, mf,
                                                        device, name)

    def _get_reversed_motions(self, slot, direction):
        """Return the motions in the slot that oppose direction, limited to
        this validator's segments."""
        return self.packets.get_reversed_motions(
                slot, direction, segment_flag=self._segments,
                ratio=END_PERCENTAGE)

    def check(self, packets, variation=None):
        """Grade the total absolute reversed motion over all slots."""
        self.init_check(packets)
        direction = self.get_direction_in_variation(variation)
        total = 0
        for slot in self.slots:
            motions = self._get_reversed_motions(slot, direction)
            self.log_details('Reversed motions slot%d: %s px' %
                             (slot, motions))
            total += sum(abs(value) for value in motions.values())
        self.vlog.score = self.fc.mf.grade(total)
        return self.vlog
723
724
class CountPacketsValidator(BaseValidator):
    """Validator to check the number of packets.

    Example:
        To verify if there are enough packets received about the first finger:
          CountPacketsValidator('>= 3, ~ -3', slot=0)
    """

    def __init__(self, criteria_str, mf=None, device=None, slot=0):
        self.name = self.__class__.__name__
        super(CountPacketsValidator, self).__init__(criteria_str, mf, device,
                                                    self.name)
        self.slot = slot

    def check(self, packets, variation=None):
        """Check the number of packets in the specified slot."""
        self.init_check(packets)
        # Get the number of packets in that slot
        actual_count_packets = self.packets.get_num_packets(self.slot)
        msg = 'Number of packets slot%d: %s'
        self.log_details(msg % (self.slot, actual_count_packets))

        # Add the metric for the count of packets
        expected_count_packets = self.get_threshold(self.criteria_str, '>')
        # Compare with None: a threshold of 0 (e.g. '>= 0') is valid but
        # falsy.
        assert expected_count_packets is not None, (
                'Check the criteria of %s' % self.name)
        metric_value = (actual_count_packets, expected_count_packets)
        metric_name = self.mnprops.COUNT_PACKETS
        self.vlog.metrics = [firmware_log.Metric(metric_name, metric_value)]

        self.vlog.score = self.fc.mf.grade(actual_count_packets)
        return self.vlog
756
757
class PinchValidator(BaseValidator):
    """Validator to check the pinch to zoom in/out.

    The relative motion of the fingers in slots 0 and 1 is graded. For the
    GV.ZOOM_OUT variation the sign of the relative motion is flipped,
    presumably so that the same criteria applies to both zoom directions.

    Example:
        To verify that the two fingers move relative to each other by at
        least 200 px in the direction expected for the variation:
          PinchValidator('>= 200, ~ -100')
    """

    def __init__(self, criteria_str, mf=None, device=None):
        """Initialize PinchValidator

        @param criteria_str: the criteria string
        @param mf: the fuzzy member function to use
        @param device: the touch device
        """
        self.name = self.__class__.__name__
        super(PinchValidator, self).__init__(criteria_str, mf, device,
                                             self.name)

    def check(self, packets, variation=None):
        """Check the relative motion of the two pinching fingers.

        @param packets: the packets to examine
        @param variation: the gesture variation; GV.ZOOM_OUT negates the
                measured relative motion. Defaults to None like the other
                validators, in which case the motion is used as measured.
        @return: self.vlog with the PINCH metric and the score set
        """
        self.init_check(packets)
        # Get the relative motion of the two fingers
        slots = (0, 1)
        actual_relative_motion = self.packets.get_relative_motion(slots)
        if variation == GV.ZOOM_OUT:
            actual_relative_motion = -actual_relative_motion
        msg = 'Relative motions of the two fingers: %.2f px'
        self.log_details(msg % actual_relative_motion)

        # Add the metric for relative motion distance.
        expected_relative_motion = self.get_threshold(self.criteria_str, '>')
        assert expected_relative_motion, 'Check the criteria of %s' % self.name
        metric_value = (actual_relative_motion, expected_relative_motion)
        metric_name = self.mnprops.PINCH
        self.vlog.metrics = [firmware_log.Metric(metric_name, metric_value)]

        self.vlog.score = self.fc.mf.grade(actual_relative_motion)
        return self.vlog
791
792
class PhysicalClickValidator(BaseValidator):
    """Validator to check the events generated by physical clicks.

    Example:
        To verify the events generated by a one-finger physical click
          PhysicalClickValidator('== 1', fingers=1)
    """

    def __init__(self, criteria_str, fingers, mf=None, device=None):
        """Initialize PhysicalClickValidator

        @param criteria_str: the criteria string, e.g., '== 1'
        @param fingers: the number of fingers used to make the clicks
        @param mf: the fuzzy member function to use
        @param device: the touch device
        """
        self.criteria_str = criteria_str
        self.name = self.__class__.__name__
        super(PhysicalClickValidator, self).__init__(criteria_str, mf, device,
                                                     self.name)
        self.fingers = fingers

    def _get_expected_number(self):
        """Get the expected number of counts from the criteria string.

        E.g., criteria_str: '== 1' yields 1.

        Prints an error and exits the process with status 1 if the text
        after the last '==' cannot be parsed as an int.
        """
        # sys.exit() is used instead of the site-injected exit() builtin,
        # which is only meant for the interactive interpreter.
        import sys

        try:
            expected_count = int(self.criteria_str.split('==')[-1].strip())
        except (AttributeError, TypeError, ValueError) as e:
            print('Error: %s in the criteria string of %s' % (e, self.name))
            sys.exit(1)
        return expected_count

    def _add_metrics(self):
        """Add the click metrics to self.vlog.

        @return: a tuple (correct_click_count, raw_click_count)
        """
        raw_click_count = self.packets.get_raw_physical_clicks()

        # This is for the metric:
        #   "of the n clicks, the % of clicks with the correct finger IDs"
        correct_click_count = self.packets.get_correct_physical_clicks(
            self.fingers)
        value_with_TIDs = (correct_click_count, raw_click_count)
        name_with_TIDs = self.mnprops.CLICK_CHECK_TIDS.format(self.fingers)

        # This is for the metric: "% of finger IDs with a click"
        expected_click_count = self._get_expected_number()
        value_clicks = (raw_click_count, expected_click_count)
        name_clicks = self.mnprops.CLICK_CHECK_CLICK.format(self.fingers)

        self.vlog.metrics = [
            firmware_log.Metric(name_with_TIDs, value_with_TIDs),
            firmware_log.Metric(name_clicks, value_clicks),
        ]

        return value_with_TIDs

    def check(self, packets, variation=None):
        """Check the physical clicks made with self.fingers fingers.

        @param packets: the packets to examine
        @param variation: unused by this validator
        @return: self.vlog with the click metrics and the score graded on the
                count from packets.get_correct_physical_clicks(self.fingers)
        """
        self.init_check(packets)
        correct_click_count, raw_click_count = self._add_metrics()
        # Get the number of physical clicks made with the specified number
        # of fingers.
        msg = 'Count of %d-finger physical clicks: %s'
        self.log_details(msg % (self.fingers, correct_click_count))
        self.log_details('Count of physical clicks: %d' % raw_click_count)
        self.vlog.score = self.fc.mf.grade(correct_click_count)
        return self.vlog
854
855
class DrumrollValidator(BaseValidator):
    """Validator to check the drumroll problem.

    All points from the same finger should be within 2 circles of radius X mm
    (e.g. 2 mm)

    Example:
        To verify that the max radius of all minimal enclosing circles generated
        by alternately tapping the index and middle fingers is within 2.0 mm.
          DrumrollValidator('<= 2.0')
    """

    def __init__(self, criteria_str, mf=None, device=None):
        """Initialize DrumrollValidator

        @param criteria_str: the criteria string
        @param mf: the fuzzy member function to use
        @param device: the touch device
        """
        name = self.__class__.__name__
        super(DrumrollValidator, self).__init__(criteria_str, mf, device, name)

    def check(self, packets, variation=None):
        """The moving distance of the points in any tracking ID should be
        within the specified value.

        @param packets: the packets to examine
        @param variation: unused by this validator
        @return: self.vlog with one CIRCLE_RADIUS metric per radius and the
                score graded on the largest radius (0 if there is none)
        """
        self.init_check(packets)
        # For each tracking ID, compute the minimal enclosing circles,
        #     rocs = (radius_of_circle1, radius_of_circle2)
        # The radii of all tracking IDs come back in a single list; its
        # elements are graded and formatted with '%.2f' below, so they are
        # expected to be plain numbers.
        rocs = self.packets.get_list_of_rocs_of_all_tracking_ids()
        if rocs:
            max_radius = max(rocs)
        else:
            # max() of an empty sequence raises ValueError. An empty list
            # (presumably no tracking IDs) has no radius, so fall back to 0.
            max_radius = 0
            self.log_details('No minimal enclosing circles found.')
        self.log_details('Max radius: %.2f mm' % max_radius)
        metric_name = self.mnprops.CIRCLE_RADIUS
        self.vlog.metrics = [firmware_log.Metric(metric_name, roc)
                             for roc in rocs or []]
        self.vlog.score = self.fc.mf.grade(max_radius)
        return self.vlog
888
889
class NoLevelJumpValidator(BaseValidator):
    """Validator to check if there are level jumps

    When a user draws a horizontal line with thumb edge or a fat finger,
    the line could comprise a horizontal line segment followed by another
    horizontal line segment (or just dots) one level up or down, and then
    another horizontal line segment again at different horizontal level, etc.
    This validator is implemented to detect such level jumps.

    Such level jumps could also occur when drawing vertical or diagonal lines.

    Example:
        To verify the level jumps in a one-finger tracking gesture:
          NoLevelJumpValidator('<= 10, ~ +30', slots=[0,])
        where slots=[0,] represents the slots with numbers larger than slot 0.
        This kind of representation is required because when the thumb edge or
        a fat finger is used, due to the difficulty in handling it correctly
        in the touch device firmware, the tracking IDs and slot IDs may keep
        changing. We would like to analyze all such slots.
    """

    def __init__(self, criteria_str, mf=None, device=None, slots=0):
        """Initialize NoLevelJumpValidator

        @param criteria_str: the criteria string
        @param mf: the fuzzy member function to use
        @param device: the touch device
        @param slots: a sequence such as [0,] whose first element is passed to
                packets.get_displacements_for_slots(), or that value given
                directly as a plain int (the default, 0)
        """
        name = self.__class__.__name__
        super(NoLevelJumpValidator, self).__init__(criteria_str, mf, device,
                                                   name)
        self.slots = slots

    def check(self, packets, variation=None):
        """Check if there are level jumps.

        @param packets: the packets to examine
        @param variation: unused by this validator
        @return: self.vlog with the score graded on the largest accumulated
                level jump over all tracking IDs and axes (0 if none)
        """
        self.init_check(packets)
        # Get the displacements of the slots. The documented form is a
        # sequence like [0,]; a plain int such as the default 0 cannot be
        # subscripted, so it is used as is.
        try:
            slots = self.slots[0]
        except TypeError:
            slots = self.slots
        displacements = self.packets.get_displacements_for_slots(slots)

        # Iterate through the collected tracking IDs
        jumps = []
        for tid in displacements:
            slot = displacements[tid][MTB.SLOT]
            for axis in AXIS.LIST:
                disp = displacements[tid][axis]
                jump = self.packets.get_largest_accumulated_level_jumps(disp)
                jumps.append(jump)
                msg = '  accu jump (%d %s): %d px'
                self.log_details(msg % (slot, axis, jump))

        # Get the largest accumulated level jump
        max_jump = max(jumps) if jumps else 0
        msg = 'Max accu jump: %d px'
        self.log_details(msg % (max_jump))
        self.vlog.score = self.fc.mf.grade(max_jump)
        return self.vlog
941
942
class ReportRateValidator(BaseValidator):
    """Validator to check the report rate.

    Example:
        To verify that the report rate is around 80 Hz. It gets 0 points
        if the report rate drops below 60 Hz.
          ReportRateValidator('== 80, ~ -20')
    """

    def __init__(self, criteria_str, finger=None, mf=None, device=None,
                 chop_off_pauses=True):
        """Initialize ReportRateValidator

        @param criteria_str: the criteria string
        @param finger: the ith contact if not None. When set to None, it means
                to examine all packets. Contact 0 is a valid value.
        @param mf: the fuzzy member function to use
        @param device: the touch device
        @param chop_off_pauses: if True, the points within
                conf.MIN_MOVING_DISTANCE of either end of a finger path are
                excluded when computing the time interval metrics
        """
        self.name = self.__class__.__name__
        self.criteria_str = criteria_str
        self.finger = finger
        if finger is not None:
            msg = '%s: finger = %d (It is required that finger >= 0.)'
            assert finger >= 0, msg % (self.name, finger)
        self.chop_off_pauses = chop_off_pauses
        super(ReportRateValidator, self).__init__(criteria_str, mf, device,
                                                  self.name)

    def _chop_off_both_ends(self, points, distance):
        """Chop off both ends of segments such that the points in the remaining
        middle segment are distant from both ends by at least the specified
        distance.

        When performing a gesture such as finger tracking, it is possible
        that the finger will stay stationary for a while before it actually
        starts moving. Likewise, it is also possible that the finger may stay
        stationary before the finger leaves the touch surface. We would like
        to chop off the stationary segments.

        Note: if distance is 0, the effect is equivalent to keep all points.

        @param points: a list of Points
        @param distance: the distance within which the points are chopped off
        @return: [begin_index, end_index] of the middle segment, or None if
                there are at most 2 points or no such segment exists
        """
        def _find_index(points, distance, reversed_flag=False):
            """Find the first index of the point whose distance with the
            first point is at least the specified distance.

            @param points: a list of Points
            @param distance: the distance
            @param reversed_flag: indicates if the points needs to be reversed
            @return: the index into the original (unreversed) list, or None
            """
            points_len = len(points)
            if reversed_flag:
                points = reversed(points)

            ref_point = None
            for i, p in enumerate(points):
                if ref_point is None:
                    ref_point = p
                if ref_point.distance(p) >= distance:
                    # Map an index in the reversed order back to the
                    # original order.
                    return (points_len - i - 1) if reversed_flag else i

            return None

        # There must be extra points in addition to the first and the last point
        if len(points) <= 2:
            return None

        begin_moving_index = _find_index(points, distance, reversed_flag=False)
        end_moving_index = _find_index(points, distance, reversed_flag=True)

        if (begin_moving_index is None or end_moving_index is None or
                begin_moving_index > end_moving_index):
            return None
        return [begin_moving_index, end_moving_index]

    def _add_report_rate_metrics2(self):
        """Calculate and add the metrics about report rate.

        Three metrics are added to self.vlog.metrics:
        - the count of time intervals longer than conf.max_report_interval
          ms, together with the total count of intervals
        - the average time interval in ms
        - the max time interval in ms

        Only the SYN_REPORT times within the middle moving segment of the
        finger path(s) are considered. If that segment cannot be derived or
        holds fewer than two SYN_REPORTs, a warning is printed and no
        metrics are added.
        """
        import test_conf as conf

        # Contact 0 is a valid finger index, so test against None rather than
        # relying on truthiness, which would treat finger=0 as all fingers.
        if self.finger is not None:
            finger_list = [self.finger]
        else:
            ordered_finger_paths_dict = self.packets.get_ordered_finger_paths()
            finger_list = range(len(ordered_finger_paths_dict))

        # distance: the minimal moving distance within which the points
        #           at both ends will be chopped off
        distance = conf.MIN_MOVING_DISTANCE if self.chop_off_pauses else 0

        # Derive the middle moving segment in which the finger(s) moves
        # significantly: it spans from the earliest begin time to the latest
        # end time among the per-finger moving segments.
        begin_time = float('infinity')
        end_time = float('-infinity')
        for finger in finger_list:
            list_t = self.packets.get_ordered_finger_path(finger, 'syn_time')
            points = self.packets.get_ordered_finger_path(finger, 'point')
            middle = self._chop_off_both_ends(points, distance)
            if middle:
                this_begin_index, this_end_index = middle
                begin_time = min(begin_time, list_t[this_begin_index])
                end_time = max(end_time, list_t[this_end_index])

        if (begin_time == float('infinity') or end_time == float('-infinity')
                or end_time <= begin_time):
            print('Warning: %s: cannot derive a moving segment.' % self.name)
            print('begin_time:  %s' % begin_time)
            print('end_time:  %s' % end_time)
            return

        # Get the list of SYN_REPORT time in the middle moving segment.
        # A list (not a lazy filter object) is needed for the slicing below.
        list_syn_time = [t for t in self.packets.get_list_syn_time(self.finger)
                         if begin_time <= t <= end_time]

        # Each packet consists of a list of events of which the last one is
        # the sync event. The unit of sync_intervals is ms.
        sync_intervals = [1000.0 * (t_next - t_prev) for t_prev, t_next
                          in zip(list_syn_time, list_syn_time[1:])]
        if not sync_intervals:
            # Avoid dividing by zero and max() of an empty list below.
            print('Warning: %s: fewer than two SYN_REPORTs in the moving '
                  'segment.' % self.name)
            return

        max_report_interval = conf.max_report_interval

        # Calculate the metrics and add them to vlog.
        long_intervals = [s for s in sync_intervals if s > max_report_interval]
        metric_long_intervals = (len(long_intervals), len(sync_intervals))
        ave_interval = sum(sync_intervals) / len(sync_intervals)
        max_interval = max(sync_intervals)

        name_long_intervals_pct = self.mnprops.LONG_INTERVALS.format(
            '%.2f' % max_report_interval)
        name_ave_time_interval = self.mnprops.AVE_TIME_INTERVAL
        name_max_time_interval = self.mnprops.MAX_TIME_INTERVAL

        self.vlog.metrics = [
            firmware_log.Metric(name_long_intervals_pct, metric_long_intervals),
            firmware_log.Metric(name_ave_time_interval, ave_interval),
            firmware_log.Metric(name_max_time_interval, max_interval),
        ]

        self.log_details('%s: %f' % (name_ave_time_interval, ave_interval))
        self.log_details('%s: %f' % (name_max_time_interval, max_interval))
        self.log_details('# long intervals > %s ms: %d' %
                         (self.mnprops.max_report_interval_str,
                          len(long_intervals)))
        self.log_details('# total intervals: %d' % len(sync_intervals))

    def _get_report_rate(self, list_syn_time):
        """Get the report rate in Hz from the list of syn_time.

        @param list_syn_time: a list of SYN_REPORT time instants
        @return: the number of intervals divided by the elapsed time, or 0
                when there are fewer than two instants or no time elapsed
        """
        if len(list_syn_time) <= 1:
            return 0
        duration = list_syn_time[-1] - list_syn_time[0]
        if duration <= 0:
            # Guard against ZeroDivisionError for identical timestamps.
            return 0
        num_packets = len(list_syn_time) - 1
        return float(num_packets) / duration

    def check(self, packets, variation=None):
        """The report rate should be within the specified range.

        @param packets: the packets to examine
        @param variation: unused by this validator
        @return: self.vlog with the interval metrics (when derivable) and the
                score graded on the report rate over all selected packets
        """
        self.init_check(packets)
        # Get the list of syn_time based on the specified finger.
        list_syn_time = self.packets.get_list_syn_time(self.finger)
        # Get the report rate
        self.report_rate = self._get_report_rate(list_syn_time)
        self._add_report_rate_metrics2()
        self.vlog.score = self.fc.mf.grade(self.report_rate)
        return self.vlog
1123