1# Copyright 2014 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import dpkt
6import logging
7import re
8
9from autotest_lib.client.bin import test
10from autotest_lib.client.common_lib import error
11from autotest_lib.client.common_lib.cros.tendo import peerd_config
12from autotest_lib.client.cros import chrooted_avahi
13from autotest_lib.client.cros.netprotos import interface_host
14from autotest_lib.client.cros.netprotos import zeroconf
15from autotest_lib.client.cros.tendo import peerd_dbus_helper
16
17
class peerd_AdvertiseServices(test.test):
    """Test that peerd can correctly advertise services over mDNS."""
    version = 1

    # Wildcard for expected TXT values.  Compared by identity (`is`), so it
    # can never collide with a real expected value.
    ANY_VALUE = object()
    FAKE_HOST_HOSTNAME = 'test-host'
    TEST_TIMEOUT_SECONDS = 30
    TEST_SERVICE_ID = 'test-service-0'
    # Passed verbatim to peerd AND used as the expected TXT values, which
    # _check_txt_record_data() treats as regular expressions.  Keep these
    # values free of regex metacharacters.
    TEST_SERVICE_INFO = {'some_data': 'a value',
                         'other_data': 'another value'}
    TEST_SERVICE_PORT = 8080
    SERBUS_SERVICE_ID = 'serbus'
    # Expected TXT entries of the root serbus record.  Each value is either a
    # regular expression that must match the WHOLE TXT value, or ANY_VALUE.
    SERBUS_SERVICE_INFO = {
            'ver': r'1\.0',
            'id': ANY_VALUE,
            # Our service id must appear somewhere in a '.'-separated list
            # (other daemons may be advertising services via peerd too).
            'services': (r'(.+\.)?' + re.escape(TEST_SERVICE_ID) +
                         r'(\..+)?'),
    }
    SERBUS_SERVICE_PORT = 0


    def initialize(self):
        """Start a chrooted avahi and peerd, and listen for mDNS traffic.

        @raises error.TestFail if avahi does not report both a hostname and
                a DNS domain.

        """
        # Make sure these are initialized to None in case we throw
        # during self.initialize(); cleanup() skips anything still None.
        self._chrooted_avahi = None
        self._peerd = None
        self._host = None
        self._zc_listener = None
        self._chrooted_avahi = chrooted_avahi.ChrootedAvahi()
        self._chrooted_avahi.start()
        # Start up a cleaned up peerd with really verbose logging.
        self._peerd = peerd_dbus_helper.make_helper(
                peerd_config.PeerdConfig(verbosity_level=3))
        # Listen on our half of the interface pair for mDNS advertisements.
        self._host = interface_host.InterfaceHost(
                self._chrooted_avahi.unchrooted_interface_name)
        self._zc_listener = zeroconf.ZeroconfDaemon(self._host,
                                                    self.FAKE_HOST_HOSTNAME)
        # The queries for hostname/dns_domain are IPCs and therefore
        # relatively expensive.  Do them just once.
        hostname = self._chrooted_avahi.hostname
        dns_domain = self._chrooted_avahi.dns_domain
        if not hostname or not dns_domain:
            raise error.TestFail('Failed to get hostname/domain from avahi.')
        self._dns_domain = dns_domain
        self._hostname = '%s.%s' % (hostname, dns_domain)


    def cleanup(self):
        """Close everything initialize() managed to create.

        Objects are closed in reverse order of creation, so the interface
        host (built on the chrooted avahi's interface) goes away before the
        chrooted avahi itself.  A failure closing one object is logged and
        does not prevent the remaining objects from being closed; the first
        such failure is re-raised once all closes have been attempted.

        """
        first_error = None
        for obj in (self._host,
                    self._peerd,
                    self._chrooted_avahi):
            if obj is None:
                continue
            try:
                obj.close()
            except Exception as e:
                logging.exception('Failed to close %r.', obj)
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error


    def _check_txt_record_data(self, expected_data, actual_data):
        """Check the entries of a TXT record against our expectations.

        @param expected_data: dict mapping TXT keys to either a regular
                expression that the whole value must match, or ANY_VALUE.
        @param actual_data: iterable of 'key=value' strings from a TXT record.
        @return True if every expected key is present with a matching value;
                False if some value does not match (yet), so the caller can
                keep waiting.
        @raises error.TestFail on an entry without '=', an unexpected key, or
                expected keys missing from the record.

        """
        # Labels in the TXT record should be 1:1 with our service info.
        expected_entries = expected_data.copy()
        for entry in actual_data:
            # All labels should be key/value pairs.
            if '=' not in entry:
                raise error.TestFail('All TXT entries should have = separator, '
                                     'but got: %s' % entry)
            k, v = entry.split('=', 1)
            if k not in expected_entries:
                raise error.TestFail('Unexpected TXT entry key: %s' % k)
            expected = expected_entries.pop(k)
            # re.match() only anchors at the start of the string; also anchor
            # at the end so e.g. 'test-service-01' does not satisfy a pattern
            # for 'test-service-0'.
            if (expected is not self.ANY_VALUE and
                    not re.match(r'(?:%s)\Z' % expected, v)):
                # We're going to return False here rather than fail the test
                # for one tricky reason: in the root serbus record, we may
                # find that the service list does not match our expectation
                # since other daemons may be advertising services via peerd.
                # We need to basically wait for our test service to show up.
                logging.warning('Expected TXT value to match %s for '
                                'entry=%s but got value=%r instead.',
                                expected, k, v)
                return False
        if expected_entries:
            # Raise a detailed exception here, rather than return false.
            raise error.TestFail('Missing entries from TXT: %r' %
                                 expected_entries)
        return True


    def _ask_for_record(self, record_name, record_type):
        """Look up a cached record, and query for it if we don't have it.

        @param record_name: string name of record (e.g. the complete host name
                            for A records).
        @param record_type: one of dpkt.dns.DNS_*.
        @return the first matching cached record, or None if no matching
                record was cached (a request for it is sent in that case).

        """
        found_records = self._zc_listener.cached_results(
                record_name, record_type)
        if len(found_records) > 1:
            logging.warning('Found multiple records with name=%s and type=%r',
                            record_name, record_type)
        if found_records:
            logging.debug('Found record with name=%s, type=%r, value=%r.',
                          record_name, record_type, found_records[0].data)
            return found_records[0]
        logging.debug('Did not see record with name=%s and type=%r',
                      record_name, record_type)
        desired_records = [(record_name, record_type)]
        self._zc_listener.send_request(desired_records)
        return None


    def _found_service_records(self, service_id, service_info, service_port):
        """Check for the PTR, SRV and TXT records of one advertised service.

        @param service_id: service name; the PTR looked up is
                '_<service_id>._tcp.<dns domain>'.
        @param service_info: expected TXT data, in the form accepted by
                _check_txt_record_data().
        @param service_port: port the SRV record must advertise.
        @return True if all three records were found and match, False if
                some record is not available yet.
        @raises error.TestFail if the SRV record points at the wrong host or
                port, or the TXT record is malformed.

        """
        PTR_name = '_%s._tcp.%s' % (service_id, self._dns_domain)
        record_PTR = self._ask_for_record(PTR_name, dpkt.dns.DNS_PTR)
        if not record_PTR:
            return False
        # Great, we know the PTR, make sure that we can also get the SRV and
        # TXT entries.
        TXT_name = SRV_name = record_PTR.data
        record_SRV = self._ask_for_record(SRV_name, dpkt.dns.DNS_SRV)
        if record_SRV is None:
            return False
        # Index 0 is compared to the host name and index 3 to the port.
        if (record_SRV.data[0] != self._hostname or
                record_SRV.data[3] != service_port):
            raise error.TestFail('Expected SRV record data %r but got %r' %
                                 ((self._hostname, service_port),
                                  record_SRV.data))
        # TXT should exist.
        record_TXT = self._ask_for_record(TXT_name, dpkt.dns.DNS_TXT)
        if (record_TXT is None or
                not self._check_txt_record_data(service_info, record_TXT.data)):
            return False
        return True


    def _found_desired_records(self):
        """Verifies that avahi has all the records we care about.

        Asks the |self._zc_listener| for records we expect to correspond
        to our test service.  Will trigger queries if we don't find the
        expected records.

        @return True if we have all expected records, False otherwise.

        """
        logging.debug('Looking for records for %s.', self._hostname)
        # First, check that Avahi is doing the simple things and publishing
        # an A record.
        record_A = self._ask_for_record(self._hostname, dpkt.dns.DNS_A)
        if (record_A is None or
                record_A.data != self._chrooted_avahi.avahi_interface_addr):
            return False
        logging.debug('Found A record, looking for serbus records.')
        # If we can see Avahi publishing that it's there, check that it has
        # appropriate entries for its serbus master record.
        if not self._found_service_records(self.SERBUS_SERVICE_ID,
                                           self.SERBUS_SERVICE_INFO,
                                           self.SERBUS_SERVICE_PORT):
            return False
        logging.debug('Found serbus records, looking for service records.')
        # We also expect the subservices we've added to exist.
        if not self._found_service_records(self.TEST_SERVICE_ID,
                                           self.TEST_SERVICE_INFO,
                                           self.TEST_SERVICE_PORT):
            return False
        logging.debug('Found all desired records.')
        return True


    def run_once(self):
        """Expose a test service via peerd and wait for its mDNS records.

        @raises error.TestFail if the records do not all show up within
                TEST_TIMEOUT_SECONDS.

        """
        # Tell peerd about this exciting new service we have.
        self._peerd.expose_service(
                self.TEST_SERVICE_ID,
                self.TEST_SERVICE_INFO,
                mdns_options={'port': self.TEST_SERVICE_PORT})
        # Wait for advertisements of that service to appear from avahi.
        logging.info('Waiting to receive mDNS advertisements of '
                     'peerd services.')
        success, _ = self._host.run_until(self._found_desired_records,
                                          self.TEST_TIMEOUT_SECONDS)
        if not success:
            raise error.TestFail('Did not receive mDNS advertisements in time.')