1# Copyright (C) 2003-2007, 2009, 2010 Nominum, Inc.
2#
3# Permission to use, copy, modify, and distribute this software and its
4# documentation for any purpose with or without fee is hereby granted,
5# provided that the above copyright notice and this permission notice
6# appear in all copies.
7#
8# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
9# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
11# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
14# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16"""Talk to a DNS server."""
17
18from __future__ import generators
19
import errno
import select
import socket
import struct
import sys
import time

import dns.exception
import dns.inet
import dns.name
import dns.message
import dns.rdataclass
import dns.rdatatype
import dns.rrset
import dns.tsig
33
class UnexpectedSource(dns.exception.DNSException):
    """A query response arrived from an address or port other than the
    one the query was sent to."""
37
class BadResponse(dns.exception.FormError):
    """A query response that does not answer the question that was asked."""
41
42def _compute_expiration(timeout):
43    if timeout is None:
44        return None
45    else:
46        return time.time() + timeout
47
48def _wait_for(ir, iw, ix, expiration):
49    done = False
50    while not done:
51        if expiration is None:
52            timeout = None
53        else:
54            timeout = expiration - time.time()
55            if timeout <= 0.0:
56                raise dns.exception.Timeout
57        try:
58            if timeout is None:
59                (r, w, x) = select.select(ir, iw, ix)
60            else:
61                (r, w, x) = select.select(ir, iw, ix, timeout)
62        except select.error, e:
63            if e.args[0] != errno.EINTR:
64                raise e
65        done = True
66        if len(r) == 0 and len(w) == 0 and len(x) == 0:
67            raise dns.exception.Timeout
68
def _wait_for_readable(s, expiration):
    """Block until socket s is readable or in an exceptional state.

    Delegates to _wait_for, which raises dns.exception.Timeout if the
    absolute deadline expiration passes first (None means no deadline).
    """
    _wait_for(ir=[s], iw=[], ix=[s], expiration=expiration)
71
def _wait_for_writable(s, expiration):
    """Block until socket s is writable or in an exceptional state.

    Delegates to _wait_for, which raises dns.exception.Timeout if the
    absolute deadline expiration passes first (None means no deadline).
    """
    _wait_for(ir=[], iw=[s], ix=[s], expiration=expiration)
74
def _addresses_equal(af, a1, a2):
    """Return True if two socket address tuples denote the same endpoint.

    The host part (element 0) of each tuple is converted to binary with
    dns.inet.inet_pton, so different spellings of one address (e.g. two
    textual forms of the same IPv6 address) compare equal.  All remaining
    tuple elements (port, and for IPv6 flowinfo/scope id) must match
    exactly.
    """
    if dns.inet.inet_pton(af, a1[0]) != dns.inet.inet_pton(af, a2[0]):
        return False
    return a1[1:] == a2[1:]
82
def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
        ignore_unexpected=False, one_rr_per_rrset=False):
    """Return the response obtained after sending a query via UDP.

    @param q: the query
    @type q: dns.message.Message
    @param where: where to send the message
    @type where: string containing an IPv4 or IPv6 address
    @param timeout: The number of seconds to wait before the query times out.
    If None, the default, wait forever.
    @type timeout: float
    @param port: The port to which to send the message.  The default is 53.
    @type port: int
    @param af: the address family to use.  The default is None, which
    causes the address family to use to be inferred from the form of where.
    If the inference attempt fails, AF_INET is used.
    @type af: int
    @rtype: dns.message.Message object
    @param source: source address.  The default is the IPv4 wildcard address.
    @type source: string
    @param source_port: The port from which to send the message.
    The default is 0.
    @type source_port: int
    @param ignore_unexpected: If True, ignore responses from unexpected
    sources.  The default is False.
    @type ignore_unexpected: bool
    @param one_rr_per_rrset: Put each RR into its own RRset
    @type one_rr_per_rrset: bool
    @raises ValueError: af is neither AF_INET nor AF_INET6
    @raises dns.exception.Timeout: no acceptable response before the timeout
    @raises UnexpectedSource: a response came from another address and
    ignore_unexpected is False
    @raises BadResponse: the response does not answer the query
    """

    wire = q.to_wire()
    if af is None:
        try:
            af = dns.inet.af_for_address(where)
        except Exception:
            af = dns.inet.AF_INET
    if af == dns.inet.AF_INET:
        destination = (where, port)
        if source is not None:
            source = (source, source_port)
    elif af == dns.inet.AF_INET6:
        destination = (where, port, 0, 0)
        if source is not None:
            source = (source, source_port, 0, 0)
    else:
        # Without this, destination would be unbound below.
        raise ValueError('unsupported address family %r' % (af,))
    s = socket.socket(af, socket.SOCK_DGRAM, 0)
    try:
        expiration = _compute_expiration(timeout)
        s.setblocking(0)
        if source is not None:
            s.bind(source)
        _wait_for_writable(s, expiration)
        s.sendto(wire, destination)
        while True:
            _wait_for_readable(s, expiration)
            (wire, from_address) = s.recvfrom(65535)
            # Accept the reply if it came from the queried address.  For a
            # multicast destination, any sender on the right port is
            # accepted.
            if _addresses_equal(af, from_address, destination) or \
                    (dns.inet.is_multicast(where) and \
                         from_address[1:] == destination[1:]):
                break
            if not ignore_unexpected:
                raise UnexpectedSource('got a response from '
                                       '%s instead of %s' % (from_address,
                                                             destination))
    finally:
        s.close()
    r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
                              one_rr_per_rrset=one_rr_per_rrset)
    if not q.is_response(r):
        raise BadResponse
    return r
153
def _net_read(sock, count, expiration):
    """Read exactly count bytes from sock and return them as one string.

    Before each recv() the socket is waited on with _wait_for_readable, so
    dns.exception.Timeout is raised if expiration passes first.  EOFError
    is raised if the peer closes the connection (recv() returns an empty
    string) before count bytes have arrived.
    """
    pieces = []
    remaining = count
    while remaining > 0:
        _wait_for_readable(sock, expiration)
        chunk = sock.recv(remaining)
        if chunk == '':
            raise EOFError
        pieces.append(chunk)
        remaining -= len(chunk)
    return ''.join(pieces)
169
def _net_write(sock, data, expiration):
    """Send all of data on sock, coping with partial writes.

    Before each send() the socket is waited on with _wait_for_writable, so
    dns.exception.Timeout is raised if expiration passes first.
    """
    pending = data
    while len(pending) > 0:
        _wait_for_writable(sock, expiration)
        written = sock.send(pending)
        pending = pending[written:]
180
181def _connect(s, address):
182    try:
183        s.connect(address)
184    except socket.error:
185        (ty, v) = sys.exc_info()[:2]
186        if v[0] != errno.EINPROGRESS and \
187               v[0] != errno.EWOULDBLOCK and \
188               v[0] != errno.EALREADY:
189            raise v
190
def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
        one_rr_per_rrset=False):
    """Return the response obtained after sending a query via TCP.

    @param q: the query
    @type q: dns.message.Message object
    @param where: where to send the message
    @type where: string containing an IPv4 or IPv6 address
    @param timeout: The number of seconds to wait before the query times out.
    If None, the default, wait forever.  The limit covers the connect, the
    send and the receive together.
    @type timeout: float
    @param port: The port to which to send the message.  The default is 53.
    @type port: int
    @param af: the address family to use.  The default is None, which
    causes the address family to use to be inferred from the form of where.
    If the inference attempt fails, AF_INET is used.
    @type af: int
    @rtype: dns.message.Message object
    @param source: source address.  The default is the IPv4 wildcard address.
    @type source: string
    @param source_port: The port from which to send the message.
    The default is 0.
    @type source_port: int
    @param one_rr_per_rrset: Put each RR into its own RRset
    @type one_rr_per_rrset: bool
    @raises ValueError: af is neither AF_INET nor AF_INET6
    @raises dns.exception.Timeout: the exchange did not finish in time
    @raises EOFError: the server closed the connection mid-response
    @raises BadResponse: the response does not answer the query
    """

    wire = q.to_wire()
    if af is None:
        try:
            af = dns.inet.af_for_address(where)
        except Exception:
            af = dns.inet.AF_INET
    if af == dns.inet.AF_INET:
        destination = (where, port)
        if source is not None:
            source = (source, source_port)
    elif af == dns.inet.AF_INET6:
        destination = (where, port, 0, 0)
        if source is not None:
            source = (source, source_port, 0, 0)
    else:
        # Without this, destination would be unbound below.
        raise ValueError('unsupported address family %r' % (af,))
    s = socket.socket(af, socket.SOCK_STREAM, 0)
    try:
        expiration = _compute_expiration(timeout)
        s.setblocking(0)
        if source is not None:
            s.bind(source)
        _connect(s, destination)

        # DNS over TCP prefixes each message with a 2-byte big-endian
        # length.  Copying the wire into tcpmsg is inefficient, but it
        # avoids writev() or a short write that would get pushed onto the
        # net.
        tcpmsg = struct.pack("!H", len(wire)) + wire
        _net_write(s, tcpmsg, expiration)
        ldata = _net_read(s, 2, expiration)
        (length,) = struct.unpack("!H", ldata)
        wire = _net_read(s, length, expiration)
    finally:
        s.close()
    r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
                              one_rr_per_rrset=one_rr_per_rrset)
    if not q.is_response(r):
        raise BadResponse
    return r
257
def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
        timeout=None, port=53, keyring=None, keyname=None, relativize=True,
        af=None, lifetime=None, source=None, source_port=0, serial=0,
        use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
    """Return a generator for the responses to a zone transfer.

    This is a generator function, so nothing (not even argument checking)
    happens until the generator is first advanced.

    @param where: where to send the message
    @type where: string containing an IPv4 or IPv6 address
    @param zone: The name of the zone to transfer
    @type zone: dns.name.Name object or string
    @param rdtype: The type of zone transfer.  The default is
    dns.rdatatype.AXFR.
    @type rdtype: int or string
    @param rdclass: The class of the zone transfer.  The default is
    dns.rdataclass.IN.
    @type rdclass: int or string
    @param timeout: The number of seconds to wait for each response message.
    If None, the default, wait forever.  A per-message deadline is never
    later than the overall deadline set by lifetime.
    @type timeout: float
    @param port: The port to which to send the message.  The default is 53.
    @type port: int
    @param keyring: The TSIG keyring to use
    @type keyring: dict
    @param keyname: The name of the TSIG key to use
    @type keyname: dns.name.Name object or string
    @param relativize: If True, all names in the zone will be relativized to
    the zone origin.  It is essential that the relativize setting matches
    the one specified to dns.zone.from_xfr().
    @type relativize: bool
    @param af: the address family to use.  The default is None, which
    causes the address family to use to be inferred from the form of where.
    If the inference attempt fails, AF_INET is used.
    @type af: int
    @param lifetime: The total number of seconds to spend doing the transfer.
    If None, the default, then there is no limit on the time the transfer may
    take.
    @type lifetime: float
    @rtype: generator of dns.message.Message objects.
    @param source: source address.  The default is the IPv4 wildcard address.
    @type source: string
    @param source_port: The port from which to send the message.
    The default is 0.
    @type source_port: int
    @param serial: The SOA serial number to use as the base for an IXFR diff
    sequence (only meaningful if rdtype == dns.rdatatype.IXFR).
    @type serial: int
    @param use_udp: Use UDP (only meaningful for IXFR)
    @type use_udp: bool
    @param keyalgorithm: The TSIG algorithm to use; defaults to
    dns.tsig.default_algorithm
    @type keyalgorithm: string
    @raises ValueError: use_udp was requested for a non-IXFR transfer, or
    af is neither AF_INET nor AF_INET6
    @raises dns.exception.Timeout: a per-message or overall deadline passed
    @raises dns.exception.FormError: the response stream is malformed (no
    leading SOA, answers after the final SOA, IXFR base serial mismatch, or
    no TSIG on the final message when a keyring was given)
    """

    if isinstance(zone, (str, unicode)):
        zone = dns.name.from_text(zone)
    if isinstance(rdtype, (str, unicode)):
        rdtype = dns.rdatatype.from_text(rdtype)
    q = dns.message.make_query(zone, rdtype, rdclass)
    if rdtype == dns.rdatatype.IXFR:
        # The IXFR request carries the client's current serial in an SOA
        # in the authority section; the other SOA fields are placeholders.
        rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
                                    '. . %u 0 0 0 0' % serial)
        q.authority.append(rrset)
    if keyring is not None:
        q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    wire = q.to_wire()
    if af is None:
        try:
            af = dns.inet.af_for_address(where)
        except Exception:
            af = dns.inet.AF_INET
    if af == dns.inet.AF_INET:
        destination = (where, port)
        if source is not None:
            source = (source, source_port)
    elif af == dns.inet.AF_INET6:
        destination = (where, port, 0, 0)
        if source is not None:
            source = (source, source_port, 0, 0)
    else:
        # Without this, destination would be unbound below.
        raise ValueError('unsupported address family %r' % (af,))
    if use_udp:
        if rdtype != dns.rdatatype.IXFR:
            raise ValueError('cannot do a UDP AXFR')
        s = socket.socket(af, socket.SOCK_DGRAM, 0)
    else:
        s = socket.socket(af, socket.SOCK_STREAM, 0)
    # Close the socket on every error path.  This uses except/re-raise
    # rather than try/finally so that the yield below never sits inside a
    # try statement (yield inside try/finally needs Python 2.5+).
    try:
        s.setblocking(0)
        if source is not None:
            s.bind(source)
        expiration = _compute_expiration(lifetime)
        _connect(s, destination)
        if use_udp:
            _wait_for_writable(s, expiration)
            s.send(wire)
        else:
            # DNS over TCP: 2-byte big-endian length prefix.
            tcpmsg = struct.pack("!H", len(wire)) + wire
            _net_write(s, tcpmsg, expiration)
    except:
        s.close()
        raise
    done = False
    soa_rrset = None
    if relativize:
        origin = zone
        oname = dns.name.empty
    else:
        origin = None
        oname = zone
    tsig_ctx = None
    first = True
    while not done:
        try:
            # The deadline for this message is timeout from now, capped by
            # the overall lifetime deadline.  Either may be None (no limit),
            # so never compare a number against None.
            mexpiration = _compute_expiration(timeout)
            if mexpiration is None or \
                   (expiration is not None and mexpiration > expiration):
                mexpiration = expiration
            if use_udp:
                _wait_for_readable(s, mexpiration)
                (wire, from_address) = s.recvfrom(65535)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (length,) = struct.unpack("!H", ldata)
                wire = _net_read(s, length, mexpiration)
            r = dns.message.from_wire(wire, keyring=q.keyring,
                                      request_mac=q.mac, xfr=True,
                                      origin=origin, tsig_ctx=tsig_ctx,
                                      multi=True, first=first,
                                      one_rr_per_rrset=(rdtype ==
                                                        dns.rdatatype.IXFR))
            tsig_ctx = r.tsig_ctx
            first = False
            answer_index = 0
            delete_mode = False
            expecting_SOA = False
            if soa_rrset is None:
                # The first message must start with the zone's SOA.
                if not r.answer or r.answer[0].name != oname:
                    raise dns.exception.FormError(
                        "No answer or RRset not for qname")
                rrset = r.answer[0]
                if rrset.rdtype != dns.rdatatype.SOA:
                    raise dns.exception.FormError("first RRset is not an SOA")
                answer_index = 1
                soa_rrset = rrset.copy()
                if rdtype == dns.rdatatype.IXFR:
                    if soa_rrset[0].serial == serial:
                        #
                        # We're already up-to-date.
                        #
                        done = True
                    else:
                        expecting_SOA = True
            #
            # Process SOAs in the answer section (other than the initial
            # SOA in the first message).
            #
            for rrset in r.answer[answer_index:]:
                if done:
                    raise dns.exception.FormError("answers after final SOA")
                if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
                    if expecting_SOA:
                        if rrset[0].serial != serial:
                            raise dns.exception.FormError(
                                "IXFR base serial mismatch")
                        expecting_SOA = False
                    elif rdtype == dns.rdatatype.IXFR:
                        # In an IXFR, each further SOA toggles between
                        # the "deleted" and "added" halves of a diff.
                        delete_mode = not delete_mode
                    if rrset == soa_rrset and not delete_mode:
                        done = True
                elif expecting_SOA:
                    #
                    # We made an IXFR request and are expecting another
                    # SOA RR, but saw something else, so this must be an
                    # AXFR response.
                    #
                    rdtype = dns.rdatatype.AXFR
                    expecting_SOA = False
            if done and q.keyring and not r.had_tsig:
                raise dns.exception.FormError("missing TSIG")
        except:
            s.close()
            raise
        yield r
    s.close()
429