1# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import unittest
6
7import dpkt
8import fake_host
9import socket
10import zeroconf
11
12
# Hostname passed to the ZeroconfDaemon under test.
FAKE_HOSTNAME = 'fakehost1'

# IPv4 address given to the FakeHost; its A record is expected to carry it.
FAKE_IPADDR = '192.168.11.22'
16
17
class TestZeroconfDaemon(unittest.TestCase):
    """Test class for ZeroconfDaemon."""

    def setUp(self):
        """Creates a fake host and a ZeroconfDaemon bound to it."""
        self._host = fake_host.FakeHost(FAKE_IPADDR)
        self._zero = zeroconf.ZeroconfDaemon(self._host, FAKE_HOSTNAME)


    def _query_A(self, name):
        """Returns the list of A records matching the given name.

        @param name: A domain name.
        @return a list of dpkt.dns.DNS.RR objects, one for each matching record.
        """
        q = dpkt.dns.DNS.Q(name=name, type=dpkt.dns.DNS_A)
        return self._zero._process_A(q)


    def _first_of_type(self, records, rrtype):
        """Returns the first record in |records| whose type is |rrtype|.

        Fails the current test with a descriptive message, instead of raising
        IndexError, when no record of the requested type is present.

        @param records: an iterable of dpkt.dns.DNS.RR objects.
        @param rrtype: a DNS record type constant, e.g. dpkt.dns.DNS_SRV.
        @return the first matching dpkt.dns.DNS.RR object.
        """
        for record in records:
            if record.type == rrtype:
                return record
        self.fail('No record of type %r in %r' % (rrtype, records))


    def testRegisterService(self):
        """Tests that we get appropriate records after registering a service."""
        SERVICE_PORT = 9
        SERVICE_TXT_LIST = ['lies=lies']
        self._zero.register_service('unique_prefix', '_service_type',
                                    '_tcp', SERVICE_PORT, SERVICE_TXT_LIST)
        name = '_service_type._tcp.local'
        fq_name = 'unique_prefix.' + name
        # Issue SRV, PTR, and TXT queries.
        q_srv = dpkt.dns.DNS.Q(name=fq_name, type=dpkt.dns.DNS_SRV)
        q_txt = dpkt.dns.DNS.Q(name=fq_name, type=dpkt.dns.DNS_TXT)
        q_ptr = dpkt.dns.DNS.Q(name=name, type=dpkt.dns.DNS_PTR)
        ptr_responses = self._zero._process_PTR(q_ptr)
        srv_responses = self._zero._process_SRV(q_srv)
        txt_responses = self._zero._process_TXT(q_txt)
        self.assertTrue(ptr_responses)
        self.assertTrue(srv_responses)
        self.assertTrue(txt_responses)
        # A response list may also contain records of other types (which is
        # why the SRV answer has to be filtered), so select every answer by
        # its record type rather than by its position in the list.
        ptr_resp = self._first_of_type(ptr_responses, dpkt.dns.DNS_PTR)
        srv_resp = self._first_of_type(srv_responses, dpkt.dns.DNS_SRV)
        txt_resp = self._first_of_type(txt_responses, dpkt.dns.DNS_TXT)
        # Check that basic things are right.
        self.assertEqual(fq_name, ptr_resp.ptrname)
        self.assertEqual(FAKE_HOSTNAME + '.' + self._zero.domain,
                         srv_resp.srvname)
        self.assertEqual(SERVICE_PORT, srv_resp.port)
        self.assertEqual(SERVICE_TXT_LIST, txt_resp.text)


    def testProperties(self):
        """Test the initial properties set by the constructor."""
        self.assertEqual(self._zero.host, self._host)
        self.assertEqual(self._zero.hostname, FAKE_HOSTNAME)
        self.assertEqual(self._zero.domain, 'local')  # Default domain.
        self.assertEqual(self._zero.full_hostname, FAKE_HOSTNAME + '.local')


    def testSocketInit(self):
        """Test that the constructor listens for mDNS traffic."""

        # Should create a UDP socket and bind it to the mDNS address and port.
        self.assertEqual(len(self._host._sockets), 1)
        sock = self._host._sockets[0]

        self.assertEqual(sock._family, socket.AF_INET)  # IPv4
        self.assertEqual(sock._sock_type, socket.SOCK_DGRAM)  # UDP

        # Check it is listening for UDP packets on the mDNS address and port.
        self.assertTrue(sock._bound)
        self.assertEqual(sock._bind_ip_addr, '224.0.0.251')  # mDNS address
        self.assertEqual(sock._bind_port, 5353)  # mDNS port
        self.assertTrue(callable(sock._bind_recv_callback))


    def testRecordsInit(self):
        """Test the A record of the host is registered."""
        host_A = self._query_A(self._zero.full_hostname)
        self.assertGreater(len(host_A), 0)

        record = host_A[0]
        # Check the hostname and the packed IP address.
        self.assertEqual(record.name, self._zero.full_hostname)
        self.assertEqual(record.ip, socket.inet_aton(self._host.ip_addr))


    def testDoubleTXTProcessing(self):
        """Test when more than one TXT record is present in a packet.

        An mDNS packet can include several answer records for several domains
        and record types. A corner case found in the field is an mDNS packet
        with two TXT records for the same domain name in its authoritative
        answers section while the packet itself is a query.
        """
        # Build the mDNS packet with two TXT records.
        domain_name = 'other_host.local'
        answers = [
                dpkt.dns.DNS.RR(
                        type = dpkt.dns.DNS_TXT,
                        cls = dpkt.dns.DNS_IN,
                        ttl = 120,
                        name = domain_name,
                        text = ['one', 'two']),
                dpkt.dns.DNS.RR(
                        type = dpkt.dns.DNS_TXT,
                        cls = dpkt.dns.DNS_IN,
                        ttl = 120,
                        name = domain_name,
                        text = ['two'])]
        # The packet is a query packet, with extra answers in the authoritative
        # section.
        mdns = dpkt.dns.DNS(
                op = dpkt.dns.DNS_QUERY,  # Standard query
                rcode = dpkt.dns.DNS_RCODE_NOERR,
                q = [],
                an = [],
                ns = answers)

        # Record the new answers received on the answer_calls list. The
        # observer is called with a list of answers, which extend() appends.
        answer_calls = []
        self._zero.add_answer_observer(answer_calls.extend)

        # Send the serialized packet to the registered receive callback.
        sock = self._host._sockets[0]
        cbk = sock._bind_recv_callback
        cbk(str(mdns), '1234', 5353)

        # Check that the answers callback is called with all the answers in the
        # received order.
        self.assertEqual(len(answer_calls), 2)
        ans1, ans2 = answer_calls  # Each ans is a (rrtype, rrname, data)
        self.assertEqual(ans1[2], ('one', 'two'))
        self.assertEqual(ans2[2], ('two',))

        # Check that the two records were cached.
        records = self._zero.cached_results(domain_name, dpkt.dns.DNS_TXT)
        self.assertEqual(len(records), 2)
154
if __name__ == '__main__':
    # Discover and run every TestCase in this module when run as a script.
    unittest.main()
157