1# Copyright (C) 2001-2007, 2009, 2010 Nominum, Inc.
2#
3# Permission to use, copy, modify, and distribute this software and its
4# documentation for any purpose with or without fee is hereby granted,
5# provided that the above copyright notice and this permission notice
6# appear in all copies.
7#
8# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
9# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
11# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
14# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
"""Helpers for building DNS wire-format messages."""
17
import cStringIO
import random
import struct
import time

import dns.exception
import dns.rdataclass
import dns.rdatatype
import dns.tsig
25
# Message section identifiers.  They double as indexes into
# Renderer.counts, and sections must be rendered in ascending order
# (see Renderer._set_section).
QUESTION, ANSWER, AUTHORITY, ADDITIONAL = 0, 1, 2, 3
30
31class Renderer(object):
32    """Helper class for building DNS wire-format messages.
33
34    Most applications can use the higher-level L{dns.message.Message}
35    class and its to_wire() method to generate wire-format messages.
36    This class is for those applications which need finer control
37    over the generation of messages.
38
39    Typical use::
40
41        r = dns.renderer.Renderer(id=1, flags=0x80, max_size=512)
42        r.add_question(qname, qtype, qclass)
43        r.add_rrset(dns.renderer.ANSWER, rrset_1)
44        r.add_rrset(dns.renderer.ANSWER, rrset_2)
45        r.add_rrset(dns.renderer.AUTHORITY, ns_rrset)
46        r.add_edns(0, 0, 4096)
47        r.add_rrset(dns.renderer.ADDTIONAL, ad_rrset_1)
48        r.add_rrset(dns.renderer.ADDTIONAL, ad_rrset_2)
49        r.write_header()
50        r.add_tsig(keyname, secret, 300, 1, 0, '', request_mac)
51        wire = r.get_wire()
52
53    @ivar output: where rendering is written
54    @type output: cStringIO.StringIO object
55    @ivar id: the message id
56    @type id: int
57    @ivar flags: the message flags
58    @type flags: int
59    @ivar max_size: the maximum size of the message
60    @type max_size: int
61    @ivar origin: the origin to use when rendering relative names
62    @type origin: dns.name.Name object
63    @ivar compress: the compression table
64    @type compress: dict
65    @ivar section: the section currently being rendered
66    @type section: int (dns.renderer.QUESTION, dns.renderer.ANSWER,
67    dns.renderer.AUTHORITY, or dns.renderer.ADDITIONAL)
68    @ivar counts: list of the number of RRs in each section
69    @type counts: int list of length 4
70    @ivar mac: the MAC of the rendered message (if TSIG was used)
71    @type mac: string
72    """
73
74    def __init__(self, id=None, flags=0, max_size=65535, origin=None):
75        """Initialize a new renderer.
76
77        @param id: the message id
78        @type id: int
79        @param flags: the DNS message flags
80        @type flags: int
81        @param max_size: the maximum message size; the default is 65535.
82        If rendering results in a message greater than I{max_size},
83        then L{dns.exception.TooBig} will be raised.
84        @type max_size: int
85        @param origin: the origin to use when rendering relative names
86        @type origin: dns.name.Namem or None.
87        """
88
89        self.output = cStringIO.StringIO()
90        if id is None:
91            self.id = random.randint(0, 65535)
92        else:
93            self.id = id
94        self.flags = flags
95        self.max_size = max_size
96        self.origin = origin
97        self.compress = {}
98        self.section = QUESTION
99        self.counts = [0, 0, 0, 0]
100        self.output.write('\x00' * 12)
101        self.mac = ''
102
103    def _rollback(self, where):
104        """Truncate the output buffer at offset I{where}, and remove any
105        compression table entries that pointed beyond the truncation
106        point.
107
108        @param where: the offset
109        @type where: int
110        """
111
112        self.output.seek(where)
113        self.output.truncate()
114        keys_to_delete = []
115        for k, v in self.compress.iteritems():
116            if v >= where:
117                keys_to_delete.append(k)
118        for k in keys_to_delete:
119            del self.compress[k]
120
121    def _set_section(self, section):
122        """Set the renderer's current section.
123
124        Sections must be rendered order: QUESTION, ANSWER, AUTHORITY,
125        ADDITIONAL.  Sections may be empty.
126
127        @param section: the section
128        @type section: int
129        @raises dns.exception.FormError: an attempt was made to set
130        a section value less than the current section.
131        """
132
133        if self.section != section:
134            if self.section > section:
135                raise dns.exception.FormError
136            self.section = section
137
138    def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
139        """Add a question to the message.
140
141        @param qname: the question name
142        @type qname: dns.name.Name
143        @param rdtype: the question rdata type
144        @type rdtype: int
145        @param rdclass: the question rdata class
146        @type rdclass: int
147        """
148
149        self._set_section(QUESTION)
150        before = self.output.tell()
151        qname.to_wire(self.output, self.compress, self.origin)
152        self.output.write(struct.pack("!HH", rdtype, rdclass))
153        after = self.output.tell()
154        if after >= self.max_size:
155            self._rollback(before)
156            raise dns.exception.TooBig
157        self.counts[QUESTION] += 1
158
159    def add_rrset(self, section, rrset, **kw):
160        """Add the rrset to the specified section.
161
162        Any keyword arguments are passed on to the rdataset's to_wire()
163        routine.
164
165        @param section: the section
166        @type section: int
167        @param rrset: the rrset
168        @type rrset: dns.rrset.RRset object
169        """
170
171        self._set_section(section)
172        before = self.output.tell()
173        n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
174        after = self.output.tell()
175        if after >= self.max_size:
176            self._rollback(before)
177            raise dns.exception.TooBig
178        self.counts[section] += n
179
180    def add_rdataset(self, section, name, rdataset, **kw):
181        """Add the rdataset to the specified section, using the specified
182        name as the owner name.
183
184        Any keyword arguments are passed on to the rdataset's to_wire()
185        routine.
186
187        @param section: the section
188        @type section: int
189        @param name: the owner name
190        @type name: dns.name.Name object
191        @param rdataset: the rdataset
192        @type rdataset: dns.rdataset.Rdataset object
193        """
194
195        self._set_section(section)
196        before = self.output.tell()
197        n = rdataset.to_wire(name, self.output, self.compress, self.origin,
198                             **kw)
199        after = self.output.tell()
200        if after >= self.max_size:
201            self._rollback(before)
202            raise dns.exception.TooBig
203        self.counts[section] += n
204
205    def add_edns(self, edns, ednsflags, payload, options=None):
206        """Add an EDNS OPT record to the message.
207
208        @param edns: The EDNS level to use.
209        @type edns: int
210        @param ednsflags: EDNS flag values.
211        @type ednsflags: int
212        @param payload: The EDNS sender's payload field, which is the maximum
213        size of UDP datagram the sender can handle.
214        @type payload: int
215        @param options: The EDNS options list
216        @type options: list of dns.edns.Option instances
217        @see: RFC 2671
218        """
219
220        # make sure the EDNS version in ednsflags agrees with edns
221        ednsflags &= 0xFF00FFFFL
222        ednsflags |= (edns << 16)
223        self._set_section(ADDITIONAL)
224        before = self.output.tell()
225        self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload,
226                                      ednsflags, 0))
227        if not options is None:
228            lstart = self.output.tell()
229            for opt in options:
230                stuff = struct.pack("!HH", opt.otype, 0)
231                self.output.write(stuff)
232                start = self.output.tell()
233                opt.to_wire(self.output)
234                end = self.output.tell()
235                assert end - start < 65536
236                self.output.seek(start - 2)
237                stuff = struct.pack("!H", end - start)
238                self.output.write(stuff)
239                self.output.seek(0, 2)
240            lend = self.output.tell()
241            assert lend - lstart < 65536
242            self.output.seek(lstart - 2)
243            stuff = struct.pack("!H", lend - lstart)
244            self.output.write(stuff)
245            self.output.seek(0, 2)
246        after = self.output.tell()
247        if after >= self.max_size:
248            self._rollback(before)
249            raise dns.exception.TooBig
250        self.counts[ADDITIONAL] += 1
251
252    def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
253                 request_mac, algorithm=dns.tsig.default_algorithm):
254        """Add a TSIG signature to the message.
255
256        @param keyname: the TSIG key name
257        @type keyname: dns.name.Name object
258        @param secret: the secret to use
259        @type secret: string
260        @param fudge: TSIG time fudge
261        @type fudge: int
262        @param id: the message id to encode in the tsig signature
263        @type id: int
264        @param tsig_error: TSIG error code; default is 0.
265        @type tsig_error: int
266        @param other_data: TSIG other data.
267        @type other_data: string
268        @param request_mac: This message is a response to the request which
269        had the specified MAC.
270        @param algorithm: the TSIG algorithm to use
271        @type request_mac: string
272        """
273
274        self._set_section(ADDITIONAL)
275        before = self.output.tell()
276        s = self.output.getvalue()
277        (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s,
278                                                    keyname,
279                                                    secret,
280                                                    int(time.time()),
281                                                    fudge,
282                                                    id,
283                                                    tsig_error,
284                                                    other_data,
285                                                    request_mac,
286                                                    algorithm=algorithm)
287        keyname.to_wire(self.output, self.compress, self.origin)
288        self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
289                                      dns.rdataclass.ANY, 0, 0))
290        rdata_start = self.output.tell()
291        self.output.write(tsig_rdata)
292        after = self.output.tell()
293        assert after - rdata_start < 65536
294        if after >= self.max_size:
295            self._rollback(before)
296            raise dns.exception.TooBig
297        self.output.seek(rdata_start - 2)
298        self.output.write(struct.pack('!H', after - rdata_start))
299        self.counts[ADDITIONAL] += 1
300        self.output.seek(10)
301        self.output.write(struct.pack('!H', self.counts[ADDITIONAL]))
302        self.output.seek(0, 2)
303
304    def write_header(self):
305        """Write the DNS message header.
306
307        Writing the DNS message header is done asfter all sections
308        have been rendered, but before the optional TSIG signature
309        is added.
310        """
311
312        self.output.seek(0)
313        self.output.write(struct.pack('!HHHHHH', self.id, self.flags,
314                                      self.counts[0], self.counts[1],
315                                      self.counts[2], self.counts[3]))
316        self.output.seek(0, 2)
317
318    def get_wire(self):
319        """Return the wire format message.
320
321        @rtype: string
322        """
323
324        return self.output.getvalue()
325