1# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import collections
6import dpkt
7import logging
8import socket
9import time
10
11
# A cached peer record as returned by ZeroconfDaemon.cached_results(): the
# record name, its DNS type (for example dpkt.dns.DNS_A), the decoded record
# data and the time (as returned by time.time()) when the record expires.
# The typename matches the binding so repr() is accurate and pickle can find
# the class in this module.
DnsRecord = collections.namedtuple('DnsRecord',
                                   ['rrname', 'rrtype', 'data', 'ts'])

# Well-known mDNS multicast group address and UDP port (RFC 6762).
MDNS_IP_ADDR = '224.0.0.251'
MDNS_PORT = 5353

# Bit OR-ed into the class field of a resource record to signal a cache flush.
DNS_CACHE_FLUSH = 0x8000

# When considering SRV records, clients are supposed to unilaterally prefer
# numerically lower priorities, then pick probabilistically by weight.
# See RFC2782.
# An arbitrary number that will fit in 16 bits.
DEFAULT_PRIORITY = 500
# An arbitrary number that will fit in 16 bits.
DEFAULT_WEIGHT = 500
27
def _RR_equals(rra, rrb):
    """Returns whether the two dpkt.dns.DNS.RR objects are equal.

    Two records are equal when every attribute present on either object, plus
    every field declared on dpkt.dns.DNS.RR, has the same value on both. The
    cache-flush bit of the 'cls' field is ignored, and the wire-format fields
    'rdata' and 'rlen' are skipped (see below).

    @param rra: A dpkt.dns.DNS.RR object.
    @param rrb: A dpkt.dns.DNS.RR object.
    @return: True if both records carry the same information.
    """
    # Compare all the members present in either object and on any RR object.
    # A set union works on both Python 2 and 3, unlike concatenating the
    # results of dict.keys().
    keys = (set(rra.__dict__) | set(rrb.__dict__) |
            set(dpkt.dns.DNS.RR.__slots__))
    # On RR objects, rdata is packed based on the other members and the final
    # packed string depends on other RR and Q elements on the same DNS/mDNS
    # packet. rlen is only the length of that packed rdata, so it is equally
    # derived; depending on the dpkt version it may still hold its default
    # value on a locally built record while a parsed one has the real length.
    keys.discard('rdata')
    keys.discard('rlen')
    for key in keys:
        if hasattr(rra, key) != hasattr(rrb, key):
            return False
        if not hasattr(rra, key):
            continue
        value_a = getattr(rra, key)
        value_b = getattr(rrb, key)
        if key == 'cls':
            # cls attribute should be masked for the cache flush bit.
            value_a &= ~DNS_CACHE_FLUSH
            value_b &= ~DNS_CACHE_FLUSH
        if value_a != value_b:
            return False
    return True
51
52
class ZeroconfDaemon(object):
    """Implements a simulated Zeroconf daemon running on the given host.

    This class implements the part of the Multicast DNS RFC 6762 needed to
    simulate a host exposing services or consuming services over mDNS. Local
    records registered with the register_*() methods are served in reply to
    incoming queries, and answers received from peers are kept in a cache
    that can be read with cached_results().
    """
    def __init__(self, host, hostname, domain='local'):
        """Initializes the ZeroconfDaemon running on the given host.

        For the purposes of the Zeroconf implementation, a host must have a
        hostname and a domain that defaults to 'local'. The ZeroconfDaemon will
        by default advertise the host address it is running on, which is
        required by some services.

        @param host: The Host instance where this daemon runs on. It must
                provide an 'ip_addr' attribute and a socket() factory.
        @param hostname: A string representing the hostname.
        @param domain: A string with the domain name, 'local' by default.
        """
        self._host = host
        self._hostname = hostname
        self._domain = domain
        self._response_ttl = 60 # Default TTL in seconds.

        self._a_records = {} # Local A records.
        self._srv_records = {} # Local SRV records.
        self._ptr_records = {} # Local PTR records.
        self._txt_records = {} # Local TXT records.

        # dict() of name --> (dict() of type --> (dict() of data --> timeout))
        # For example: _peer_records['somehost.local'][dpkt.dns.DNS_A] \
        #     ['192.168.0.1'] = time.time() + 3600
        self._peer_records = {}

        # Observer list for new responses. Created before the socket starts
        # listening because _mdns_request() reads it.
        self._answer_callbacks = []

        # Register the host address locally.
        self.register_A(self.full_hostname, host.ip_addr)

        # Attend all the traffic to the mDNS port (unicast, multicast or
        # broadcast).
        self._sock = host.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._sock.listen(MDNS_IP_ADDR, MDNS_PORT, self._mdns_request)


    def __del__(self):
        # __init__ may have failed before the socket was created.
        sock = getattr(self, '_sock', None)
        if sock is not None:
            sock.close()


    @property
    def host(self):
        """The Host object where this daemon is running."""
        return self._host


    @property
    def hostname(self):
        """The hostname part within a domain."""
        return self._hostname


    @property
    def domain(self):
        """The domain where the given hostname is running."""
        return self._domain


    @property
    def full_hostname(self):
        """The full hostname designation including host and domain name."""
        return self._hostname + '.' + self._domain


    def _mdns_request(self, data, addr, port):
        """Handles a mDNS multicast packet.

        This method will generate and send a mDNS response to any query
        for which it has new authoritative information. Called by the Simulator
        as a callback for every mDNS received packet.

        @param data: The string contained on the UDP message.
        @param addr: The address where the message comes from.
        @param port: The port number where the message comes from.
        """
        # Parse the mDNS request using dpkt's DNS module.
        mdns = dpkt.dns.DNS(data)
        if mdns.op == 0x0000: # Query
            self._answer_queries(mdns)

        # Always process the received authoritative answers. Copy the list so
        # the parsed packet is not modified by the extend() below.
        answers = list(mdns.ns)

        # Process the answers for response packets.
        if mdns.op == 0x8400: # Standard response
            answers.extend(mdns.an)

        self._cache_answers(answers)


    def _answer_queries(self, mdns):
        """Replies to the queries on |mdns| this daemon has records for.

        Answers that the querier already listed as known answers are omitted.

        @param mdns: The parsed dpkt.dns.DNS query packet.
        """
        query_handlers = {
            dpkt.dns.DNS_A: self._process_A,
            dpkt.dns.DNS_PTR: self._process_PTR,
            dpkt.dns.DNS_TXT: self._process_TXT,
            dpkt.dns.DNS_SRV: self._process_SRV,
        }

        answers = []
        for q in mdns.qd: # Query entries
            if q.type in query_handlers:
                answers += query_handlers[q.type](q)
            elif q.type == dpkt.dns.DNS_ANY:
                # Special type matching any known type.
                for handler in query_handlers.values():
                    answers += handler(q)
        # Remove all the already known answers from the list.
        answers = [ans for ans in answers
                   if not any(_RR_equals(known_ans, ans)
                              for known_ans in mdns.an)]

        self._send_answers(answers)


    def _cache_answers(self, answers):
        """Stores received answers in the peer cache and notifies observers.

        Records of unsupported types are ignored. Observers registered with
        add_answer_observer() are called once with all the stored answers.

        @param answers: A list of dpkt.dns.DNS.RR records received on the same
                packet.
        """
        if not answers:
            return
        cur_time = time.time()
        new_answers = []
        # (name, type) pairs already stored while processing this packet. A
        # packet may carry several records of the same name and type with the
        # cache-flush bit set (for example several A records of a host with
        # multiple addresses); they must not flush each other out.
        seen = set()
        for rr in answers: # Answers RRs
            # dpkt decodes the information on different fields depending on
            # the response type.
            if rr.type == dpkt.dns.DNS_A:
                data = socket.inet_ntoa(rr.ip)
            elif rr.type == dpkt.dns.DNS_PTR:
                data = rr.ptrname
            elif rr.type == dpkt.dns.DNS_TXT:
                data = tuple(rr.text) # Convert the list to a hashable tuple
            elif rr.type == dpkt.dns.DNS_SRV:
                data = rr.srvname, rr.priority, rr.weight, rr.port
            else:
                continue # Ignore unsupported records.
            records_by_type = self._peer_records.setdefault(rr.name, {})
            key = (rr.name, rr.type)
            # Start a new cache, or clear the existing one when a cache flush
            # is requested by the first record of this name and type on the
            # packet.
            if rr.type not in records_by_type or (
                    rr.cls & DNS_CACHE_FLUSH and key not in seen):
                records_by_type[rr.type] = {}
            seen.add(key)

            new_answers.append((rr.type, rr.name, data))
            cached_ans = records_by_type[rr.type]
            rr_timeout = cur_time + rr.ttl
            # Update the answer timeout if already cached, never shortening it.
            cached_ans[data] = max(cached_ans.get(data, rr_timeout), rr_timeout)
        if new_answers:
            for cbk in self._answer_callbacks:
                cbk(new_answers)


    def clear_cache(self):
        """Discards all the cached records."""
        self._peer_records = {}


    def _send_answers(self, answers):
        """Send a mDNS reply with the provided answers.

        This method uses the underlying Host socket to send a mDNS response
        containing the list of answers of the type dpkt.dns.DNS.RR to the mDNS
        multicast address and port. If the list is empty, no packet is sent.

        @param answers: The list of answers to send.
        """
        if not answers:
            return
        logging.debug('Sending response with answers: %r.', answers)
        resp_dns = dpkt.dns.DNS(
            op=dpkt.dns.DNS_AA, # Authoritative Answer.
            rcode=dpkt.dns.DNS_RCODE_NOERR,
            an=answers)
        # This property modifies the "op" field:
        resp_dns.qr = dpkt.dns.DNS_R # Response.
        # bytes() is str() on Python 2 and yields the packed packet on Python 3.
        self._sock.send(bytes(resp_dns), MDNS_IP_ADDR, MDNS_PORT)


    ### RFC 2782 - RR for specifying the location of services (DNS SRV).
    def register_SRV(self, service, proto, priority, weight, port):
        """Publishes the SRV specified record.

        A SRV record defines a service on a port of a host with given properties
        like priority and weight. The service has a name of the form
        "service.proto.domain". The target host, that is, the host where the
        announced service is running on, is set to the host where this zeroconf
        daemon is running, "hostname.domain".

        @param service: A string with the service name.
        @param proto: A string with the protocol name, for example "_tcp".
        @param priority: The service priority number as defined by RFC2782.
        @param weight: The service weight number as defined by RFC2782.
        @param port: The port number where the service is running on.
        """
        srvname = service + '.' + proto + '.' + self._domain
        self._srv_records[srvname] = priority, weight, port


    def _process_SRV(self, q):
        """Process a SRV query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_SRV.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty. A matching SRV answer is followed by the A records of this
        daemon's own hostname.
        """
        if q.name not in self._srv_records:
            return []
        priority, weight, port = self._srv_records[q.name]
        full_hostname = self.full_hostname
        ans = dpkt.dns.DNS.RR(
            type=dpkt.dns.DNS_SRV,
            cls=dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
            ttl=self._response_ttl,
            name=q.name,
            srvname=full_hostname,
            priority=priority,
            weight=weight,
            port=port)
        # The target host (srvname) requires to send an A record with its IP
        # address. We do this as if a query for it was sent.
        a_qry = dpkt.dns.DNS.Q(name=full_hostname, type=dpkt.dns.DNS_A)
        return [ans] + self._process_A(a_qry)


    ### RFC 1035 - 3.4.1, Domains Names - A (IPv4 address).
    def register_A(self, hostname, ip_addr):
        """Registers an Address record (A) pointing to the given IP address.

        Records registered with this method are assumed authoritative. Several
        addresses may be registered for the same hostname.

        @param hostname: The full host name, for example, "somehost.local".
        @param ip_addr: The IPv4 address of the host, for example, "192.0.1.1".
        """
        self._a_records.setdefault(hostname, []).append(
                socket.inet_aton(ip_addr))


    def _process_A(self, q):
        """Process an A query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_A.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty; one record per registered address.
        """
        return [dpkt.dns.DNS.RR(
                    type=dpkt.dns.DNS_A,
                    cls=dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
                    ttl=self._response_ttl,
                    name=q.name,
                    ip=ip_addr)
                for ip_addr in self._a_records.get(q.name, [])]


    ### RFC 1035 - 3.3.12, Domain names - PTR (domain name pointer).
    def register_PTR(self, domain, destination):
        """Register a domain pointer record.

        A domain pointer record is simply a pointer to a hostname on the domain.

        @param domain: A domain name that can include a proto name, for
        example, "_workstation._tcp.local".
        @param destination: The hostname inside the given domain, for example,
        "my-desktop".
        """
        self._ptr_records.setdefault(domain, []).append(destination)


    def _process_PTR(self, q):
        """Process a PTR query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_PTR.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty. Each answer points to "destination.<queried name>".
        """
        return [dpkt.dns.DNS.RR(
                    type=dpkt.dns.DNS_PTR,
                    cls=dpkt.dns.DNS_IN, # Don't cache flush for PTR records.
                    ttl=self._response_ttl,
                    name=q.name,
                    ptrname=dest + '.' + q.name)
                for dest in self._ptr_records.get(q.name, [])]


    ### RFC 1035 - 3.3.14, Domain names - TXT (descriptive text).
    def register_TXT(self, domain, txt_list, announce=False):
        """Register a TXT record on a domain with given list of strings.

        A TXT record can hold any list of text entries whose format depends on
        the domain. This method replaces any TXT record previously registered
        for the given domain.

        @param domain: A domain name that normally can include a proto name and
        a service or host name.
        @param txt_list: A list of strings.
        @param announce: If True, the method will also announce the changes
        on the network.
        """
        self._txt_records[domain] = txt_list
        if announce:
            txt_query = dpkt.dns.DNS.Q(name=domain, type=dpkt.dns.DNS_TXT)
            self._send_answers(self._process_TXT(txt_query))


    def _process_TXT(self, q):
        """Process a TXT query provided in |q|.

        @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_TXT.
        @return: A list of dns.DNS.RR responses to the provided query that can
        be empty.
        """
        if q.name not in self._txt_records:
            return []
        answer = dpkt.dns.DNS.RR(
            type=dpkt.dns.DNS_TXT,
            cls=dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
            ttl=self._response_ttl,
            name=q.name,
            text=self._txt_records[q.name])
        return [answer]


    def register_service(self, unique_prefix, service_type,
                         protocol, port, txt_list):
        """Register a service in the Avahi style.

        Avahi exposes a convenient set of methods for manipulating "services"
        which are a trio of PTR, SRV, and TXT records.  This is a similar
        helper method for our daemon.

        @param unique_prefix: string unique prefix of service (part of the
                              canonical name).
        @param service_type: string type of service (e.g. '_privet').
        @param protocol: string protocol to use for service (e.g. '_tcp').
        @param port: IP port of service (e.g. 53).
        @param txt_list: list of txt records (e.g. ['vers=1.0', 'foo']).
        """
        service_name = '.'.join([unique_prefix, service_type])
        fq_service_name = '.'.join([service_name, protocol, self._domain])
        logging.debug('Registering service=%s on port=%d with txt records=%r',
                      fq_service_name, port, txt_list)
        self.register_SRV(
                service_name, protocol, DEFAULT_PRIORITY, DEFAULT_WEIGHT, port)
        self.register_PTR('.'.join([service_type, protocol, self._domain]),
                          unique_prefix)
        self.register_TXT(fq_service_name, txt_list)


    def cached_results(self, rrname, rrtype, timestamp=None):
        """Return all the cached results for the requested rrname and rrtype.

        This method is used to request all the received mDNS answers present
        on the cache that were valid at the provided timestamp or later.
        Answers received before this timestamp whose TTL isn't long enough to
        make them valid at the timestamp aren't returned. On the other hand,
        answers received *after* the provided timestamp will always be
        considered, even if they weren't known at the provided timestamp point.
        A timestamp of None will return them all.

        This method allows to retrieve "volatile" answers with a TTL of zero.
        According to the RFC, these answers should be only considered for the
        "ongoing" request. To do this, call this method after a few seconds (the
        request timeout) after calling the send_request() method, passing to
        this method the returned timestamp.

        @param rrname: The requested domain name.
        @param rrtype: The DNS record type. For example, dpkt.dns.DNS_TXT.
        @param timestamp: The request timestamp. See description.
        @return: The list of matching DnsRecord tuples of the form (rrname,
                 rrtype, data, ts), where ts is the time the record expires.
        """
        if timestamp is None:
            timestamp = 0
        records = self._peer_records.get(rrname, {}).get(rrtype, {})
        return [DnsRecord(rrname, rrtype, data, data_ts)
                for data, data_ts in records.items()
                if data_ts >= timestamp]


    def send_request(self, queries):
        """Sends a request for the provided rrname and rrtype.

        All the known and valid answers for this request will be included in the
        non authoritative list of known answers together with the request. This
        is recommended by the RFC and avoid unnecessary responses.

        @param queries: A list of pairs (rrname, rrtype) where rrname is the
        domain name you are requesting for and the rrtype is the DNS record
        type. For example, ('somehost.local', dpkt.dns.DNS_ANY).
        @return: The timestamp where this request is sent. See cached_results().
        """
        queries = [dpkt.dns.DNS.Q(name=rrname, type=rrtype)
                   for rrname, rrtype in queries]
        # TODO(deymo): Include the already known answers on the request.
        answers = []
        mdns = dpkt.dns.DNS(
            op=dpkt.dns.DNS_QUERY,
            qd=queries,
            an=answers)
        # Take the timestamp before sending, so answers delivered while the
        # send is still in progress (including TTL-0 ones) are not excluded
        # by cached_results().
        request_time = time.time()
        # bytes() is str() on Python 2 and yields the packed packet on Python 3.
        self._sock.send(bytes(mdns), MDNS_IP_ADDR, MDNS_PORT)
        return request_time


    def add_answer_observer(self, callback):
        """Adds the callback to the list of observers for new answers.

        @param callback: A callable object accepting a list of tuples (rrtype,
        rrname, data) where rrtype is the DNS record type, rrname the domain
        name and data is the information associated with the answer, similar
        to the data field of what cached_results() returns. Note the order:
        the record type comes first.
        """
        self._answer_callbacks.append(callback)
478