validators.py revision d1603ecc4264a1c789e0de7e2c584ab83235a068
1# Copyright (c) 2012 The Chromium OS Authors. All rights reserved. 2# Use of this source code is governed by a BSD-style license that can be 3# found in the LICENSE file. 4 5"""Validators to verify if events conform to specified criteria.""" 6 7 8''' 9How to add a new validator/gesture: 10(1) Implement a new validator class inheriting BaseValidator, 11(2) add proper method in mtb.Mtb class, 12(3) add the new validator in test_conf, and 13 'from validators import the_new_validator' 14 in alphabetical order, and 15(4) add the validator in relevant gestures; add a new gesture if necessary. 16 17The validator template is as follows: 18 19class XxxValidator(BaseValidator): 20 """Validator to check ... 21 22 Example: 23 To check ... 24 XxxValidator('<= 0.05, ~ +0.05', fingers=2) 25 """ 26 27 def __init__(self, criteria_str, mf=None, fingers=1): 28 name = self.__class__.__name__ 29 super(X..Validator, self).__init__(criteria_str, mf, name) 30 self.fingers = fingers 31 32 def check(self, packets, variation=None): 33 """Check ...""" 34 self.init_check(packets) 35 xxx = self.packets.xxx() 36 self.print_msg(...) 37 return (self.fc.mf.grade(...), self.msg_list) 38 39 40Note that it is also possible to instantiate a validator as 41 XxxValidator('<= 0.05, ~ +0.05', slot=0) 42 43 Difference between fingers and slot: 44 . When specifying 'fingers', e.g., fingers=2, the purpose is to pass 45 the information about how many fingers there are in the gesture. In 46 this case, the events in a specific slot is usually not important. 47 An example is to check how many fingers there are when making a click: 48 PhysicalClickValidator('== 0', fingers=2) 49 . When specifying 'slot', e.g., slot=0, the purpose is pass the slot 50 number to the validator to examine detailed events in that slot. 51 An example of such usage: 52 LinearityValidator('<= 0.03, ~ +0.07', slot=0) 53''' 54 55 56import copy 57import numpy as np 58import os 59import re 60 61import firmware_log 62import fuzzy 63import mtb 64 65from collections import namedtuple 66from inspect import isfunction 67 68from common_util import print_and_exit 69from firmware_constants import AXIS, GV, MTB, UNIT, VAL 70 71 72# Define the ratio of points taken at both ends of a line for edge tests. 73END_PERCENTAGE = 0.1 74 75# Define other constants below. 76VALIDATOR = 'Validator' 77 78 79def validate(packets, gesture, variation): 80 """Validate a single gesture.""" 81 if packets is None: 82 return (None, None) 83 84 msg_list = [] 85 score_list = [] 86 vlogs = [] 87 for validator in gesture.validators: 88 vlog = validator.check(packets, variation) 89 if vlog is None: 90 continue 91 vlogs.append(copy.deepcopy(vlog)) 92 score = vlog.score 93 94 if score is not None: 95 score_list.append(score) 96 # save the validator messages 97 msg_validator_name = '%s' % vlog.name 98 msg_criteria = ' criteria_str: %s' % vlog.criteria 99 msg_score = 'score: %f' % score 100 msg_list.append(os.linesep) 101 msg_list.append(msg_validator_name) 102 msg_list += vlog.details 103 msg_list.append(msg_criteria) 104 msg_list.append(msg_score) 105 106 return (score_list, msg_list, vlogs) 107 108 109def get_short_name(validator_name): 110 """Get the short name of the validator. 111 112 E.g, the short name of LinearityValidator is Linearity. 113 """ 114 return validator_name.split(VALIDATOR)[0] 115 116 117def get_validator_name(short_name): 118 """Convert the short_name to its corresponding validator name. 119 120 E.g, the validator_name of Linearity is LinearityValidator. 121 """ 122 return short_name + VALIDATOR 123 124 125def get_base_name_and_segment(validator_name): 126 """Get the base name and segment of a validator. 127 128 Examples: 129 Ex 1: Linearity(BothEnds)Validator 130 return ('Linearity', 'BothEnds') 131 Ex 2: NoGapValidator 132 return ('NoGap', None) 133 """ 134 if '(' in validator_name: 135 result = re.search('(.*)\((.*)\)%s' % VALIDATOR, validator_name) 136 return (result.group(1), result.group(2)) 137 else: 138 return (get_short_name(validator_name), None) 139 140 141def get_derived_name(validator_name, segment): 142 """Get the derived name based on segment value. 143 144 Example: 145 validator_name: LinearityValidator 146 segment: Middle 147 derived_name: Linearity(Middle)Validator 148 """ 149 short_name = get_short_name(validator_name) 150 derived_name = '%s(%s)%s' % (short_name, segment, VALIDATOR) 151 return derived_name 152 153 154def init_base_validator(device): 155 """Initialize the device for all the Validators to use""" 156 BaseValidator._device = device 157 158 159class BaseValidator(object): 160 """Base class of validators.""" 161 aggregator = 'fuzzy.average' 162 _device = None 163 164 def __init__(self, criteria, mf=None, device=None, name=None): 165 self.criteria_str = criteria() if isfunction(criteria) else criteria 166 self.fc = fuzzy.FuzzyCriteria(self.criteria_str, mf=mf) 167 self.device = device if device else BaseValidator._device 168 self.packets = None 169 self.vlog = firmware_log.ValidatorLog() 170 self.vlog.name = name 171 self.vlog.criteria = self.criteria_str 172 self.mnprops = firmware_log.MetricNameProps() 173 174 def init_check(self, packets=None): 175 """Initialization before check() is called.""" 176 self.packets = mtb.Mtb(device=self.device, packets=packets) 177 self.vlog.reset() 178 179 def _is_direction_in_variation(self, variation, directions): 180 """Is any element of directions list found in variation?""" 181 for direction in directions: 182 if direction in variation: 183 return True 184 return False 185 186 def is_horizontal(self, variation): 187 """Is the direction horizontal?""" 188 return self._is_direction_in_variation(variation, 189 GV.HORIZONTAL_DIRECTIONS) 190 191 def is_vertical(self, variation): 192 """Is the direction vertical?""" 193 return self._is_direction_in_variation(variation, 194 GV.VERTICAL_DIRECTIONS) 195 196 def is_diagonal(self, variation): 197 """Is the direction diagonal?""" 198 return self._is_direction_in_variation(variation, 199 GV.DIAGONAL_DIRECTIONS) 200 201 def get_direction(self, variation): 202 """Get the direction.""" 203 # TODO(josephsih): raise an exception if a proper direction is not found 204 if self.is_horizontal(variation): 205 return GV.HORIZONTAL 206 elif self.is_vertical(variation): 207 return GV.VERTICAL 208 elif self.is_diagonal(variation): 209 return GV.DIAGONAL 210 211 def get_direction_in_variation(self, variation): 212 """Get the direction string from the variation list.""" 213 if isinstance(variation, tuple): 214 for var in variation: 215 if var in GV.GESTURE_DIRECTIONS: 216 return var 217 elif variation in GV.GESTURE_DIRECTIONS: 218 return variation 219 return None 220 221 def log_details(self, msg): 222 """Collect the detailed messages to be printed within this module.""" 223 prefix_space = ' ' * 4 224 formatted_msg = '%s%s' % (prefix_space, msg) 225 self.vlog.insert_details(formatted_msg) 226 227 def get_threshold(self, criteria_str, op): 228 """Search the criteria_str using regular expressions and get 229 the threshold value. 230 231 @param criteria_str: the criteria string to search 232 """ 233 # In the search pattern, '.*?' is non-greedy, which will match as 234 # few characters as possible. 235 # E.g., op = '>' 236 # criteria_str = '>= 200, ~ -100' 237 # pattern below would be '>.*?\s*(\d+)' 238 # result.group(1) below would be '200' 239 pattern = '{}.*?\s*(\d+)'.format(op) 240 result = re.search(pattern, criteria_str) 241 return int(result.group(1)) if result else None 242 243 def _get_axes_by_finger(self, finger): 244 """Get list_x, list_y, and list_t for the specified finger. 245 246 @param finger: the finger contact 247 """ 248 points = self.packets.get_ordered_finger_path(self.finger, 'point') 249 list_x = [p.x for p in points] 250 list_y = [p.y for p in points] 251 list_t = self.packets.get_ordered_finger_path(self.finger, 'syn_time') 252 return (list_x, list_y, list_t) 253 254 255class LinearityValidator1(BaseValidator): 256 """Validator to verify linearity. 257 258 Example: 259 To check the linearity of the line drawn in finger 1: 260 LinearityValidator1('<= 0.03, ~ +0.07', finger=1) 261 """ 262 # Define the partial group size for calculating Mean Squared Error 263 MSE_PARTIAL_GROUP_SIZE = 1 264 265 def __init__(self, criteria_str, mf=None, device=None, finger=0, 266 segments=VAL.WHOLE): 267 self._segments = segments 268 self.finger = finger 269 name = get_derived_name(self.__class__.__name__, segments) 270 super(LinearityValidator1, self).__init__(criteria_str, mf, device, 271 name) 272 273 def _simple_linear_regression(self, ax, ay): 274 """Calculate the simple linear regression and returns the 275 sum of squared residuals. 276 277 It calculates the simple linear regression line for the points 278 in the middle segment of the line. This exclude the points at 279 both ends of the line which sometimes have wobbles. Then it 280 calculates the fitting errors of the points at the specified segments 281 against the computed simple linear regression line. 282 """ 283 # Compute the simple linear regression line for the middle segment 284 # whose purpose is to avoid wobbles on both ends of the line. 285 mid_segment = self.packets.get_segments_x_and_y(ax, ay, VAL.MIDDLE, 286 END_PERCENTAGE) 287 if not self._calc_simple_linear_regression_line(*mid_segment): 288 return 0 289 290 # Compute the fitting errors of the specified segments. 291 if self._segments == VAL.BOTH_ENDS: 292 bgn_segment = self.packets.get_segments_x_and_y(ax, ay, VAL.BEGIN, 293 END_PERCENTAGE) 294 end_segment = self.packets.get_segments_x_and_y(ax, ay, VAL.END, 295 END_PERCENTAGE) 296 bgn_error = self._calc_simple_linear_regression_error(*bgn_segment) 297 end_error = self._calc_simple_linear_regression_error(*end_segment) 298 return max(bgn_error, end_error) 299 else: 300 target_segment = self.packets.get_segments_x_and_y(ax, ay, 301 self._segments, END_PERCENTAGE) 302 return self._calc_simple_linear_regression_error(*target_segment) 303 304 def _calc_simple_linear_regression_line(self, ax, ay): 305 """Calculate the simple linear regression line. 306 307 ax: array x 308 ay: array y 309 This method tries to find alpha and beta in the formula 310 ay = alpha + beta . ax 311 such that it has the least sum of squared residuals. 312 313 Reference: 314 - Simple linear regression: 315 http://en.wikipedia.org/wiki/Simple_linear_regression 316 - Average absolute deviation (or mean absolute deviation) : 317 http://en.wikipedia.org/wiki/Average_absolute_deviation 318 """ 319 # Convert the int list to the float array 320 self._ax = 1.0 * np.array(ax) 321 self._ay = 1.0 * np.array(ay) 322 323 # If there are less than 2 data points, it is not a line at all. 324 asize = self._ax.size 325 if asize <= 2: 326 return False 327 328 Sx = self._ax.sum() 329 Sy = self._ay.sum() 330 Sxx = np.square(self._ax).sum() 331 Sxy = np.dot(self._ax, self._ay) 332 Syy = np.square(self._ay).sum() 333 Sx2 = Sx * Sx 334 Sy2 = Sy * Sy 335 336 # compute Mean of x and y 337 Mx = self._ax.mean() 338 My = self._ay.mean() 339 340 # Compute beta and alpha of the linear regression 341 self._beta = 1.0 * (asize * Sxy - Sx * Sy) / (asize * Sxx - Sx2) 342 self._alpha = My - self._beta * Mx 343 return True 344 345 def _calc_simple_linear_regression_error(self, ax, ay): 346 """Calculate the fitting error based on the simple linear regression 347 line characterized by the equation parameters alpha and beta. 348 """ 349 # Convert the int list to the float array 350 ax = 1.0 * np.array(ax) 351 ay = 1.0 * np.array(ay) 352 353 asize = ax.size 354 partial = min(asize, max(1, self.MSE_PARTIAL_GROUP_SIZE)) 355 356 # spmse: squared root of partial mean squared error 357 spmse = np.square(ay - self._alpha - self._beta * ax) 358 spmse.sort() 359 spmse = spmse[asize - partial : asize] 360 spmse = np.sqrt(np.average(spmse)) 361 return spmse 362 363 def check(self, packets, variation=None): 364 """Check if the packets conforms to specified criteria.""" 365 self.init_check(packets) 366 resolution_x, resolution_y = self.device.get_resolutions() 367 (list_x, list_y) = self.packets.get_x_y(self.finger) 368 # Compute average distance (fitting error) in pixels, and 369 # average deviation on touch device in mm. 370 if self.is_vertical(variation): 371 ave_distance = self._simple_linear_regression(list_y, list_x) 372 deviation = ave_distance / resolution_x 373 else: 374 ave_distance = self._simple_linear_regression(list_x, list_y) 375 deviation = ave_distance / resolution_y 376 377 self.log_details('ave fitting error: %.2f px' % ave_distance) 378 msg_device = 'deviation finger%d: %.2f mm' 379 self.log_details(msg_device % (self.finger, deviation)) 380 self.vlog.score = self.fc.mf.grade(deviation) 381 return self.vlog 382 383 384class LinearityValidator(BaseValidator): 385 """A validator to verify linearity based on x-t and y-t 386 387 Example: 388 To check the linearity of the line drawn in finger 1: 389 LinearityValidator('<= 0.03, ~ +0.07', finger=1) 390 Note: the finger number begins from 0 391 """ 392 # Define the partial group size for calculating Mean Squared Error 393 MSE_PARTIAL_GROUP_SIZE = 1 394 395 def __init__(self, criteria_str, mf=None, device=None, finger=0, 396 segments=VAL.WHOLE): 397 self._segments = segments 398 self.finger = finger 399 name = get_derived_name(self.__class__.__name__, segments) 400 super(LinearityValidator, self).__init__(criteria_str, mf, device, 401 name) 402 403 def _calc_residuals(self, line, list_t, list_y): 404 """Calculate the residuals of the points in list_t, list_y against 405 the line. 406 407 @param line: the regression line of list_t and list_y 408 @param list_t: a list of time instants 409 @param list_y: a list of x/y coordinates 410 411 This method returns the list of residuals, where 412 residual[i] = line[t_i] - y_i 413 where t_i is an element in list_t and 414 y_i is a corresponding element in list_y. 415 416 We calculate the vertical distance (y distance) here because the 417 horizontal axis, list_t, always represent the time instants, and the 418 vertical axis, list_y, could be either the coordinates in x or y axis. 419 """ 420 return [float(line(t) - y) for t, y in zip(list_t, list_y)] 421 422 def _do_simple_linear_regression(self, list_t, list_y): 423 """Calculate the simple linear regression line and returns the 424 sum of squared residuals. 425 426 @param list_t: the list of time instants 427 @param list_y: the list of x or y coordinates of touch contacts 428 429 It calculates the residuals (fitting errors) of the points at the 430 specified segments against the computed simple linear regression line. 431 432 Reference: 433 - Simple linear regression: 434 http://en.wikipedia.org/wiki/Simple_linear_regression 435 - numpy.polyfit(): used to calculate the simple linear regression line. 436 http://docs.scipy.org/doc/numpy/reference/generated/numpy.polyfit.html 437 """ 438 # At least 2 points to determine a line. 439 if len(list_t) < 2 or len(list_y) < 2: 440 return [] 441 442 mid_segment_t, mid_segment_y = self.packets.get_segments( 443 list_t, list_y, VAL.MIDDLE, END_PERCENTAGE) 444 445 # Calculate the simple linear regression line. 446 degree = 1 447 regress_line = np.poly1d(np.polyfit(mid_segment_t, mid_segment_y, 448 degree)) 449 450 # Compute the fitting errors of the specified segments. 451 if self._segments == VAL.BOTH_ENDS: 452 begin_segments = self.packets.get_segments( 453 list_t, list_y, VAL.BEGIN, END_PERCENTAGE) 454 end_segments = self.packets.get_segments( 455 list_t, list_y, VAL.END, END_PERCENTAGE) 456 begin_error = self._calc_residuals(regress_line, *begin_segments) 457 end_error = self._calc_residuals(regress_line, *end_segments) 458 return begin_error + end_error 459 else: 460 target_segments = self.packets.get_segments( 461 list_t, list_y, self._segments, END_PERCENTAGE) 462 return self._calc_residuals(regress_line, *target_segments) 463 464 def _calc_errors_single_axis(self, list_t, list_y): 465 """Calculate various errors for axis-time. 466 467 @param list_t: the list of time instants 468 @param list_y: the list of x or y coordinates of touch contacts 469 """ 470 # It is fine if axis-time is a horizontal line. 471 errors_px = self._do_simple_linear_regression(list_t, list_y) 472 if not errors_px: 473 return (0, 0) 474 475 # Calculate the max errors 476 max_err_px = max(map(abs, errors_px)) 477 478 # Calculate the root mean square errors 479 e2 = [e * e for e in errors_px] 480 rms_err_px = (float(sum(e2)) / len(e2)) ** 0.5 481 482 return (max_err_px, rms_err_px) 483 484 def _calc_errors_all_axes(self, list_t, list_x, list_y): 485 """Calculate various errors for all axes.""" 486 # Calculate max error and average squared error 487 (max_err_x_px, rms_err_x_px) = self._calc_errors_single_axis( 488 list_t, list_x) 489 (max_err_y_px, rms_err_y_px) = self._calc_errors_single_axis( 490 list_t, list_y) 491 492 # Convert the unit from pixels to mms 493 self.max_err_x_mm, self.max_err_y_mm = self.device.pixel_to_mm( 494 (max_err_x_px, max_err_y_px)) 495 self.rms_err_x_mm, self.rms_err_y_mm = self.device.pixel_to_mm( 496 (rms_err_x_px, rms_err_y_px)) 497 498 def _log_details_and_metrics(self, variation): 499 """Log the details and calculate the metrics. 500 501 @param variation: the gesture variation 502 """ 503 list_x, list_y, list_t = self._get_axes_by_finger(self.finger) 504 X, Y = AXIS.LIST 505 # For horizontal lines, only consider x axis 506 if self.is_horizontal(variation): 507 self.list_coords = {X: list_x} 508 # For vertical lines, only consider y axis 509 elif self.is_vertical(variation): 510 self.list_coords = {Y: list_y} 511 # For diagonal lines, consider both x and y axes 512 elif self.is_diagonal(variation): 513 self.list_coords = {X: list_x, Y: list_y} 514 515 self.max_err_mm = {} 516 self.rms_err_mm = {} 517 self.vlog.metrics = [] 518 mnprops = self.mnprops 519 pixel_to_mm = self.device.pixel_to_mm_single_axis_by_name 520 for axis, list_c in self.list_coords.items(): 521 max_err_px, rms_err_px = self._calc_errors_single_axis( 522 list_t, list_c) 523 max_err_mm = pixel_to_mm(max_err_px, axis) 524 rms_err_mm = pixel_to_mm(rms_err_px, axis) 525 self.log_details('max_err[%s]: %.2f mm' % (axis, max_err_mm)) 526 self.log_details('rms_err[%s]: %.2f mm' % (axis, rms_err_mm)) 527 self.vlog.metrics.extend([ 528 firmware_log.Metric(mnprops.MAX_ERR.format(axis), max_err_mm), 529 firmware_log.Metric(mnprops.RMS_ERR.format(axis), rms_err_mm), 530 ]) 531 self.max_err_mm[axis] = max_err_mm 532 self.rms_err_mm[axis] = rms_err_mm 533 534 def check(self, packets, variation=None): 535 """Check if the packets conforms to specified criteria.""" 536 self.init_check(packets) 537 self._log_details_and_metrics(variation) 538 # Calculate the score based on the max error 539 max_err = max(self.max_err_mm.values()) 540 self.vlog.score = self.fc.mf.grade(max_err) 541 return self.vlog 542 543 544class RangeValidator(BaseValidator): 545 """Validator to check the observed (x, y) positions should be within 546 the range of reported min/max values. 547 548 Example: 549 To check the range of observed edge-to-edge positions: 550 RangeValidator('<= 0.05, ~ +0.05') 551 """ 552 553 def __init__(self, criteria_str, mf=None, device=None): 554 self.name = self.__class__.__name__ 555 super(RangeValidator, self).__init__(criteria_str, mf, device, 556 self.name) 557 558 def check(self, packets, variation=None): 559 """Check the left/right or top/bottom range based on the direction.""" 560 self.init_check(packets) 561 valid_directions = [GV.CL, GV.CR, GV.CT, GV.CB] 562 Range = namedtuple('Range', valid_directions) 563 actual_range = Range(*self.packets.get_range()) 564 spec_range = Range(self.device.axis_x.min, self.device.axis_x.max, 565 self.device.axis_y.min, self.device.axis_y.max) 566 567 direction = self.get_direction_in_variation(variation) 568 if direction in valid_directions: 569 actual_edge = getattr(actual_range, direction) 570 spec_edge = getattr(spec_range, direction) 571 short_of_range_px = abs(actual_edge - spec_edge) 572 else: 573 err_msg = 'Error: the gesture variation %s is not allowed in %s.' 574 print_and_exit(err_msg % (variation, self.name)) 575 576 axis_spec = (self.device.axis_x if self.is_horizontal(variation) 577 else self.device.axis_y) 578 deviation_ratio = (float(short_of_range_px) / 579 (axis_spec.max - axis_spec.min)) 580 # Convert the direction to edge name. 581 # E.g., direction: center_to_left 582 # edge name: left 583 edge_name = direction.split('_')[-1] 584 metric_name = self.mnprops.RANGE.format(edge_name) 585 short_of_range_mm = self.device.pixel_to_mm_single_axis( 586 short_of_range_px, axis_spec) 587 self.vlog.metrics = [ 588 firmware_log.Metric(metric_name, short_of_range_mm) 589 ] 590 self.log_details('actual: px %s' % str(actual_edge)) 591 self.log_details('spec: px %s' % str(spec_edge)) 592 self.log_details('short of range: %d px == %f mm' % 593 (short_of_range_px, short_of_range_mm)) 594 self.vlog.score = self.fc.mf.grade(deviation_ratio) 595 return self.vlog 596 597 598class CountTrackingIDValidator(BaseValidator): 599 """Validator to check the count of tracking IDs. 600 601 Example: 602 To verify if there is exactly one finger observed: 603 CountTrackingIDValidator('== 1') 604 """ 605 606 def __init__(self, criteria_str, mf=None, device=None): 607 name = self.__class__.__name__ 608 super(CountTrackingIDValidator, self).__init__(criteria_str, mf, 609 device, name) 610 611 def check(self, packets, variation=None): 612 """Check the number of tracking IDs observed.""" 613 self.init_check(packets) 614 615 # Get the actual count of tracking id and log the details. 616 actual_count_tid = self.packets.get_number_contacts() 617 self.log_details('count of trackid IDs: %d' % actual_count_tid) 618 619 # Only keep metrics with the criteria '== N'. 620 # Ignore those with '>= N' which are used to assert that users have 621 # performed correct gestures. As an example, we require that users 622 # tap more than a certain number of times in the drumroll test. 623 if '==' in self.criteria_str: 624 expected_count_tid = int(self.criteria_str.split('==')[-1].strip()) 625 # E.g., expected_count_tid = 2 626 # actual_count_tid could be either smaller (e.g., 1) or 627 # larger (e.g., 3). 628 metric_value = (actual_count_tid, expected_count_tid) 629 metric_name = self.mnprops.TID 630 self.vlog.metrics = [firmware_log.Metric(metric_name, metric_value)] 631 632 self.vlog.score = self.fc.mf.grade(actual_count_tid) 633 return self.vlog 634 635 636class StationaryFingerValidator(BaseValidator): 637 """Validator to check the count of tracking IDs. 638 639 Example: 640 To verify if the stationary finger specified by the slot does not 641 move larger than a specified radius: 642 StationaryFingerValidator('<= 15 ~ +10') 643 """ 644 645 def __init__(self, criteria, mf=None, device=None, slot=0): 646 name = self.__class__.__name__ 647 super(StationaryFingerValidator, self).__init__(criteria, mf, 648 device, name) 649 self.slot = slot 650 651 def check(self, packets, variation=None): 652 """Check the moving distance of the specified finger.""" 653 self.init_check(packets) 654 max_distance = self.packets.get_max_distance(self.slot, UNIT.MM) 655 msg = 'Max distance slot%d: %d mm' 656 self.log_details(msg % (self.slot, max_distance)) 657 self.vlog.metrics = [ 658 firmware_log.Metric(self.mnprops.MAX_DISTANCE, max_distance) 659 ] 660 self.vlog.score = self.fc.mf.grade(max_distance) 661 return self.vlog 662 663 664class NoGapValidator(BaseValidator): 665 """Validator to make sure that there are no significant gaps in a line. 666 667 Example: 668 To verify if there is exactly one finger observed: 669 NoGapValidator('<= 5, ~ +5', slot=1) 670 """ 671 672 def __init__(self, criteria_str, mf=None, device=None, slot=0): 673 name = self.__class__.__name__ 674 super(NoGapValidator, self).__init__(criteria_str, mf, device, name) 675 self.slot = slot 676 677 def check(self, packets, variation=None): 678 """There should be no significant gaps in a line.""" 679 self.init_check(packets) 680 # Get the largest gap ratio 681 gap_ratio = self.packets.get_largest_gap_ratio(self.slot) 682 msg = 'Largest gap ratio slot%d: %f' 683 self.log_details(msg % (self.slot, gap_ratio)) 684 self.vlog.score = self.fc.mf.grade(gap_ratio) 685 return self.vlog 686 687 688class NoReversedMotionValidator(BaseValidator): 689 """Validator to measure the reversed motions in the specified slots. 690 691 Example: 692 To measure the reversed motions in slot 0: 693 NoReversedMotionValidator('== 0, ~ +20', slots=0) 694 """ 695 def __init__(self, criteria_str, mf=None, device=None, slots=(0,), 696 segments=VAL.MIDDLE): 697 self._segments = segments 698 name = get_derived_name(self.__class__.__name__, segments) 699 self.slots = (slots,) if isinstance(slots, int) else slots 700 parent = super(NoReversedMotionValidator, self) 701 parent.__init__(criteria_str, mf, device, name) 702 703 def _get_reversed_motions(self, slot, direction): 704 """Get the reversed motions opposed to the direction in the slot.""" 705 return self.packets.get_reversed_motions(slot, 706 direction, 707 segment_flag=self._segments, 708 ratio=END_PERCENTAGE) 709 710 def check(self, packets, variation=None): 711 """There should be no reversed motions in a slot.""" 712 self.init_check(packets) 713 sum_reversed_motions = 0 714 direction = self.get_direction_in_variation(variation) 715 for slot in self.slots: 716 # Get the reversed motions. 717 reversed_motions = self._get_reversed_motions(slot, direction) 718 msg = 'Reversed motions slot%d: %s px' 719 self.log_details(msg % (slot, reversed_motions)) 720 sum_reversed_motions += sum(map(abs, reversed_motions.values())) 721 self.vlog.score = self.fc.mf.grade(sum_reversed_motions) 722 return self.vlog 723 724 725class CountPacketsValidator(BaseValidator): 726 """Validator to check the number of packets. 727 728 Example: 729 To verify if there are enough packets received about the first finger: 730 CountPacketsValidator('>= 3, ~ -3', slot=0) 731 """ 732 733 def __init__(self, criteria_str, mf=None, device=None, slot=0): 734 self.name = self.__class__.__name__ 735 super(CountPacketsValidator, self).__init__(criteria_str, mf, device, 736 self.name) 737 self.slot = slot 738 739 def check(self, packets, variation=None): 740 """Check the number of packets in the specified slot.""" 741 self.init_check(packets) 742 # Get the number of packets in that slot 743 actual_count_packets = self.packets.get_num_packets(self.slot) 744 msg = 'Number of packets slot%d: %s' 745 self.log_details(msg % (self.slot, actual_count_packets)) 746 747 # Add the metric for the count of packets 748 expected_count_packets = self.get_threshold(self.criteria_str, '>') 749 assert expected_count_packets, 'Check the criteria of %s' % self.name 750 metric_value = (actual_count_packets, expected_count_packets) 751 metric_name = self.mnprops.COUNT_PACKETS 752 self.vlog.metrics = [firmware_log.Metric(metric_name, metric_value)] 753 754 self.vlog.score = self.fc.mf.grade(actual_count_packets) 755 return self.vlog 756 757 758class PinchValidator(BaseValidator): 759 """Validator to check the pinch to zoom in/out. 760 761 Example: 762 To verify that the two fingers are drawing closer: 763 PinchValidator('>= 200, ~ -100') 764 """ 765 766 def __init__(self, criteria_str, mf=None, device=None): 767 self.name = self.__class__.__name__ 768 super(PinchValidator, self).__init__(criteria_str, mf, device, 769 self.name) 770 771 def check(self, packets, variation): 772 """Check the number of packets in the specified slot.""" 773 self.init_check(packets) 774 # Get the relative motion of the two fingers 775 slots = (0, 1) 776 actual_relative_motion = self.packets.get_relative_motion(slots) 777 if variation == GV.ZOOM_OUT: 778 actual_relative_motion = -actual_relative_motion 779 msg = 'Relative motions of the two fingers: %.2f px' 780 self.log_details(msg % actual_relative_motion) 781 782 # Add the metric for relative motion distance. 783 expected_relative_motion = self.get_threshold(self.criteria_str, '>') 784 assert expected_relative_motion, 'Check the criteria of %s' % self.name 785 metric_value = (actual_relative_motion, expected_relative_motion) 786 metric_name = self.mnprops.PINCH 787 self.vlog.metrics = [firmware_log.Metric(metric_name, metric_value)] 788 789 self.vlog.score = self.fc.mf.grade(actual_relative_motion) 790 return self.vlog 791 792 793class PhysicalClickValidator(BaseValidator): 794 """Validator to check the events generated by physical clicks 795 796 Example: 797 To verify the events generated by a one-finger physical click 798 PhysicalClickValidator('== 1', fingers=1) 799 """ 800 801 def __init__(self, criteria_str, fingers, mf=None, device=None): 802 self.criteria_str = criteria_str 803 self.name = self.__class__.__name__ 804 super(PhysicalClickValidator, self).__init__(criteria_str, mf, device, 805 self.name) 806 self.fingers = fingers 807 808 def _get_expected_number(self): 809 """Get the expected number of counts from the criteria string. 810 811 E.g., criteria_str: '== 1' 812 """ 813 try: 814 expected_count = int(self.criteria_str.split('==')[-1].strip()) 815 except Exception, e: 816 print 'Error: %s in the criteria string of %s' % (e, self.name) 817 exit(1) 818 return expected_count 819 820 def _add_metrics(self): 821 """Add metrics""" 822 fingers = self.fingers 823 raw_click_count = self.packets.get_raw_physical_clicks() 824 825 # This is for the metric: 826 # "of the n clicks, the % of clicks with the correct finger IDs" 827 correct_click_count = self.packets.get_correct_physical_clicks(fingers) 828 value_with_TIDs = (correct_click_count, raw_click_count) 829 name_with_TIDs = self.mnprops.CLICK_CHECK_TIDS.format(self.fingers) 830 831 # This is for the metric: "% of finger IDs with a click" 832 expected_click_count = self._get_expected_number() 833 value_clicks = (raw_click_count, expected_click_count) 834 name_clicks = self.mnprops.CLICK_CHECK_CLICK.format(self.fingers) 835 836 self.vlog.metrics = [ 837 firmware_log.Metric(name_with_TIDs, value_with_TIDs), 838 firmware_log.Metric(name_clicks, value_clicks), 839 ] 840 841 return value_with_TIDs 842 843 def check(self, packets, variation=None): 844 """Check the number of packets in the specified slot.""" 845 self.init_check(packets) 846 correct_click_count, raw_click_count = self._add_metrics() 847 # Get the number of physical clicks made with the specified number 848 # of fingers. 849 msg = 'Count of %d-finger physical clicks: %s' 850 self.log_details(msg % (self.fingers, correct_click_count)) 851 self.log_details('Count of physical clicks: %d' % raw_click_count) 852 self.vlog.score = self.fc.mf.grade(correct_click_count) 853 return self.vlog 854 855 856class DrumrollValidator(BaseValidator): 857 """Validator to check the drumroll problem. 858 859 All points from the same finger should be within 2 circles of radius X mm 860 (e.g. 2 mm) 861 862 Example: 863 To verify that the max radius of all minimal enclosing circles generated 864 by alternately tapping the index and middle fingers is within 2.0 mm. 865 DrumrollValidator('<= 2.0') 866 """ 867 868 def __init__(self, criteria_str, mf=None, device=None): 869 name = self.__class__.__name__ 870 super(DrumrollValidator, self).__init__(criteria_str, mf, device, name) 871 872 def check(self, packets, variation=None): 873 """The moving distance of the points in any tracking ID should be 874 within the specified value. 875 """ 876 self.init_check(packets) 877 # For each tracking ID, compute the minimal enclosing circles, 878 # rocs = (radius_of_circle1, radius_of_circle2) 879 # Return a list of such minimal enclosing circles of all tracking IDs. 880 rocs = self.packets.get_list_of_rocs_of_all_tracking_ids() 881 max_radius = max(rocs) 882 self.log_details('Max radius: %.2f mm' % max_radius) 883 metric_name = self.mnprops.CIRCLE_RADIUS 884 self.vlog.metrics = [firmware_log.Metric(metric_name, roc) 885 for roc in rocs] 886 self.vlog.score = self.fc.mf.grade(max_radius) 887 return self.vlog 888 889 890class NoLevelJumpValidator(BaseValidator): 891 """Validator to check if there are level jumps 892 893 When a user draws a horizontal line with thumb edge or a fat finger, 894 the line could comprise a horizontal line segment followed by another 895 horizontal line segment (or just dots) one level up or down, and then 896 another horizontal line segment again at different horizontal level, etc. 897 This validator is implemented to detect such level jumps. 898 899 Such level jumps could also occur when drawing vertical or diagonal lines. 900 901 Example: 902 To verify the level jumps in a one-finger tracking gesture: 903 NoLevelJumpValidator('<= 10, ~ +30', slots[0,]) 904 where slots[0,] represent the slots with numbers larger than slot 0. 905 This kind of representation is required because when the thumb edge or 906 a fat finger is used, due to the difficulty in handling it correctly 907 in the touch device firmware, the tracking IDs and slot IDs may keep 908 changing. We would like to analyze all such slots. 909 """ 910 911 def __init__(self, criteria_str, mf=None, device=None, slots=0): 912 name = self.__class__.__name__ 913 super(NoLevelJumpValidator, self).__init__(criteria_str, mf, device, 914 name) 915 self.slots = slots 916 917 def check(self, packets, variation=None): 918 """Check if there are level jumps.""" 919 self.init_check(packets) 920 # Get the displacements of the slots. 921 slots = self.slots[0] 922 displacements = self.packets.get_displacements_for_slots(slots) 923 924 # Iterate through the collected tracking IDs 925 jumps = [] 926 for tid in displacements: 927 slot = displacements[tid][MTB.SLOT] 928 for axis in AXIS.LIST: 929 disp = displacements[tid][axis] 930 jump = self.packets.get_largest_accumulated_level_jumps(disp) 931 jumps.append(jump) 932 msg = ' accu jump (%d %s): %d px' 933 self.log_details(msg % (slot, axis, jump)) 934 935 # Get the largest accumulated level jump 936 max_jump = max(jumps) if jumps else 0 937 msg = 'Max accu jump: %d px' 938 self.log_details(msg % (max_jump)) 939 self.vlog.score = self.fc.mf.grade(max_jump) 940 return self.vlog 941 942 943class ReportRateValidator(BaseValidator): 944 """Validator to check the report rate. 945 946 Example: 947 To verify that the report rate is around 80 Hz. It gets 0 points 948 if the report rate drops below 60 Hz. 949 ReportRateValidator('== 80 ~ -20') 950 """ 951 952 def __init__(self, criteria_str, finger=None, mf=None, device=None, 953 chop_off_pauses=True): 954 """Initialize ReportRateValidator 955 956 @param criteria_str: the criteria string 957 @param finger: the ith contact if not None. When set to None, it means 958 to examine all packets. 959 @param mf: the fuzzy member function to use 960 @param device: the touch device 961 """ 962 self.name = self.__class__.__name__ 963 self.criteria_str = criteria_str 964 self.finger = finger 965 if finger is not None: 966 msg = '%s: finger = %d (It is required that finger >= 0.)' 967 assert finger >= 0, msg % (self.name, finger) 968 self.chop_off_pauses = chop_off_pauses 969 super(ReportRateValidator, self).__init__(criteria_str, mf, device, 970 self.name) 971 972 def _chop_off_both_ends(self, points, distance): 973 """Chop off both ends of segments such that the points in the remaining 974 middle segment are distant from both ends by more than the specified 975 distance. 976 977 When performing a gesture such as finger tracking, it is possible 978 that the finger will stay stationary for a while before it actually 979 starts moving. Likewise, it is also possible that the finger may stay 980 stationary before the finger leaves the touch surface. We would like 981 to chop off the stationary segments. 982 983 Note: if distance is 0, the effect is equivalent to keep all points. 984 985 @param points: a list of Points 986 @param distance: the distance within which the points are chopped off 987 """ 988 def _find_index(points, distance, reversed_flag=False): 989 """Find the first index of the point whose distance with the 990 first point is larger than the specified distance. 991 992 @param points: a list of Points 993 @param distance: the distance 994 @param reversed_flag: indicates if the points needs to be reversed 995 """ 996 points_len = len(points) 997 if reversed_flag: 998 points = reversed(points) 999 1000 ref_point = None 1001 for i, p in enumerate(points): 1002 if ref_point is None: 1003 ref_point = p 1004 if ref_point.distance(p) >= distance: 1005 return (points_len - i - 1) if reversed_flag else i 1006 1007 return None 1008 1009 # There must be extra points in addition to the first and the last point 1010 if len(points) <= 2: 1011 return None 1012 1013 begin_moving_index = _find_index(points, distance, reversed_flag=False) 1014 end_moving_index = _find_index(points, distance, reversed_flag=True) 1015 1016 if (begin_moving_index is None or end_moving_index is None or 1017 begin_moving_index > end_moving_index): 1018 return None 1019 return [begin_moving_index, end_moving_index] 1020 1021 def _add_report_rate_metrics2(self): 1022 """Calculate and add the metrics about report rate. 1023 1024 Three metrics are required. 1025 - % of time intervals that are > (1/60) second 1026 - average time interval 1027 - max time interval 1028 1029 """ 1030 import test_conf as conf 1031 1032 if self.finger: 1033 finger_list = [self.finger] 1034 else: 1035 ordered_finger_paths_dict = self.packets.get_ordered_finger_paths() 1036 finger_list = range(len(ordered_finger_paths_dict)) 1037 1038 # distance: the minimal moving distance within which the points 1039 # at both ends will be chopped off 1040 distance = conf.MIN_MOVING_DISTANCE if self.chop_off_pauses else 0 1041 1042 # Derive the middle moving segment in which the finger(s) 1043 # moves significantly. 1044 begin_time = float('infinity') 1045 end_time = float('-infinity') 1046 for finger in finger_list: 1047 list_t = self.packets.get_ordered_finger_path(finger, 'syn_time') 1048 points = self.packets.get_ordered_finger_path(finger, 'point') 1049 middle = self._chop_off_both_ends(points, distance) 1050 if middle: 1051 this_begin_index, this_end_index = middle 1052 this_begin_time = list_t[this_begin_index] 1053 this_end_time = list_t[this_end_index] 1054 begin_time = min(begin_time, this_begin_time) 1055 end_time = max(end_time, this_end_time) 1056 1057 if (begin_time == float('infinity') or end_time == float('-infinity') 1058 or end_time <= begin_time): 1059 print 'Warning: %s: cannot derive a moving segment.' % self.name 1060 print 'begin_time: ', begin_time 1061 print 'end_time: ', end_time 1062 return 1063 1064 # Get the list of SYN_REPORT time in the middle moving segment. 1065 list_syn_time = filter(lambda t: t >= begin_time and t <= end_time, 1066 self.packets.get_list_syn_time(self.finger)) 1067 1068 # Each packet consists of a list of events of which The last one is 1069 # the sync event. The unit of sync_intervals is ms. 1070 sync_intervals = [1000.0 * (list_syn_time[i + 1] - list_syn_time[i]) 1071 for i in range(len(list_syn_time) - 1)] 1072 1073 max_report_interval = conf.max_report_interval 1074 1075 # Calculate the metrics and add them to vlog. 1076 long_intervals = [s for s in sync_intervals if s > max_report_interval] 1077 metric_long_intervals = (len(long_intervals), len(sync_intervals)) 1078 ave_interval = sum(sync_intervals) / len(sync_intervals) 1079 max_interval = max(sync_intervals) 1080 1081 name_long_intervals_pct = self.mnprops.LONG_INTERVALS.format( 1082 '%.2f' % max_report_interval) 1083 name_ave_time_interval = self.mnprops.AVE_TIME_INTERVAL 1084 name_max_time_interval = self.mnprops.MAX_TIME_INTERVAL 1085 1086 self.vlog.metrics = [ 1087 firmware_log.Metric(name_long_intervals_pct, metric_long_intervals), 1088 firmware_log.Metric(self.mnprops.AVE_TIME_INTERVAL, ave_interval), 1089 firmware_log.Metric(self.mnprops.MAX_TIME_INTERVAL, max_interval), 1090 ] 1091 1092 self.log_details('%s: %f' % (self.mnprops.AVE_TIME_INTERVAL, 1093 ave_interval)) 1094 self.log_details('%s: %f' % (self.mnprops.MAX_TIME_INTERVAL, 1095 max_interval)) 1096 self.log_details('# long intervals > %s ms: %d' % 1097 (self.mnprops.max_report_interval_str, 1098 len(long_intervals))) 1099 self.log_details('# total intervals: %d' % len(sync_intervals)) 1100 1101 def _get_report_rate(self, list_syn_time): 1102 """Get the report rate in Hz from the list of syn_time. 1103 1104 @param list_syn_time: a list of SYN_REPORT time instants 1105 """ 1106 if len(list_syn_time) <= 1: 1107 return 0 1108 duration = list_syn_time[-1] - list_syn_time[0] 1109 num_packets = len(list_syn_time) - 1 1110 report_rate = float(num_packets) / duration 1111 return report_rate 1112 1113 def check(self, packets, variation=None): 1114 """The Report rate should be within the specified range.""" 1115 self.init_check(packets) 1116 # Get the list of syn_time based on the specified finger. 1117 list_syn_time = self.packets.get_list_syn_time(self.finger) 1118 # Get the report rate 1119 self.report_rate = self._get_report_rate(list_syn_time) 1120 self._add_report_rate_metrics2() 1121 self.vlog.score = self.fc.mf.grade(self.report_rate) 1122 return self.vlog 1123