1# Copyright (C) 2001-2007, 2009, 2010 Nominum, Inc.
2#
3# Permission to use, copy, modify, and distribute this software and its
4# documentation for any purpose with or without fee is hereby granted,
5# provided that the above copyright notice and this permission notice
6# appear in all copies.
7#
8# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
9# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
11# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
14# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16"""DNS Names.
17
18@var root: The DNS root name.
19@type root: dns.name.Name object
20@var empty: The empty DNS name.
21@type empty: dns.name.Name object
22"""
23
24import cStringIO
25import struct
26import sys
27
28if sys.hexversion >= 0x02030000:
29    import encodings.idna
30
31import dns.exception
32
# Relationship codes returned as the first element of Name.fullcompare():
#   NONE           - the names share no significant labels
#   SUPERDOMAIN    - the first name is a proper superdomain of the second
#   SUBDOMAIN      - the first name is a proper subdomain of the second
#   EQUAL          - the names are equal (ignoring case)
#   COMMONANCESTOR - the names share some, but not all, significant labels
(NAMERELN_NONE,
 NAMERELN_SUPERDOMAIN,
 NAMERELN_SUBDOMAIN,
 NAMERELN_EQUAL,
 NAMERELN_COMMONANCESTOR) = range(5)
38
class EmptyLabel(dns.exception.SyntaxError):
    """Raised if a label is empty anywhere other than the final (root)
    position of a name, e.g. the text 'a..b'."""
42
class BadEscape(dns.exception.SyntaxError):
    """Raised if a backslash escape in a text-format name is malformed."""
46
class BadPointer(dns.exception.FormError):
    """Raised if a wire-format compression pointer does not point strictly
    backward in the message (i.e. it points forward or at itself)."""
50
class BadLabelType(dns.exception.FormError):
    """Raised if a wire-format label is neither an ordinary label nor a
    compression pointer."""
54
class NeedAbsoluteNameOrOrigin(dns.exception.DNSException):
    """Raised if a relative name must be converted to an absolute form
    (e.g. wire format) but no absolute origin was supplied to complete it."""
59
class NameTooLong(dns.exception.FormError):
    """Raised if a name would exceed 255 octets in wire format."""
63
class LabelTooLong(dns.exception.SyntaxError):
    """Raised if a single label exceeds 63 octets."""
67
class AbsoluteConcatenation(dns.exception.DNSException):
    """Raised if a non-empty name is appended to an absolute name; only the
    empty name may follow an absolute name."""
72
class NoParent(dns.exception.DNSException):
    """Raised if the parent of the root name or of the empty name is
    requested; neither has one."""
77
78_escaped = {
79    '"' : True,
80    '(' : True,
81    ')' : True,
82    '.' : True,
83    ';' : True,
84    '\\' : True,
85    '@' : True,
86    '$' : True
87    }
88
def _escapify(label):
    """Escape the characters in label which need it.

    Characters listed in C{_escaped} get a backslash prefix, printable
    ASCII (0x21-0x7E) is copied through unchanged, and everything else
    (space, control characters, and values >= 0x7F) is written as a
    three-digit decimal escape C{\\DDD}.

    @param label: the label to escape
    @type label: string
    @returns: the escaped string
    @rtype: string"""
    # Collect pieces and join once; repeated "+=" on strings is quadratic
    # in the general case.
    parts = []
    for c in label:
        if c in _escaped:
            parts.append('\\' + c)
        elif 0x20 < ord(c) < 0x7F:
            parts.append(c)
        else:
            parts.append('\\%03d' % ord(c))
    return ''.join(parts)
102
103def _validate_labels(labels):
104    """Check for empty labels in the middle of a label sequence,
105    labels that are too long, and for too many labels.
106    @raises NameTooLong: the name as a whole is too long
107    @raises LabelTooLong: an individual label is too long
108    @raises EmptyLabel: a label is empty (i.e. the root label) and appears
109    in a position other than the end of the label sequence"""
110
111    l = len(labels)
112    total = 0
113    i = -1
114    j = 0
115    for label in labels:
116        ll = len(label)
117        total += ll + 1
118        if ll > 63:
119            raise LabelTooLong
120        if i < 0 and label == '':
121            i = j
122        j += 1
123    if total > 255:
124        raise NameTooLong
125    if i >= 0 and i != l - 1:
126        raise EmptyLabel
127
128class Name(object):
129    """A DNS name.
130
131    The dns.name.Name class represents a DNS name as a tuple of labels.
132    Instances of the class are immutable.
133
134    @ivar labels: The tuple of labels in the name. Each label is a string of
135    up to 63 octets."""
136
137    __slots__ = ['labels']
138
139    def __init__(self, labels):
140        """Initialize a domain name from a list of labels.
141        @param labels: the labels
142        @type labels: any iterable whose values are strings
143        """
144
145        super(Name, self).__setattr__('labels', tuple(labels))
146        _validate_labels(self.labels)
147
148    def __setattr__(self, name, value):
149        raise TypeError("object doesn't support attribute assignment")
150
151    def is_absolute(self):
152        """Is the most significant label of this name the root label?
153        @rtype: bool
154        """
155
156        return len(self.labels) > 0 and self.labels[-1] == ''
157
158    def is_wild(self):
159        """Is this name wild?  (I.e. Is the least significant label '*'?)
160        @rtype: bool
161        """
162
163        return len(self.labels) > 0 and self.labels[0] == '*'
164
165    def __hash__(self):
166        """Return a case-insensitive hash of the name.
167        @rtype: int
168        """
169
170        h = 0L
171        for label in self.labels:
172            for c in label:
173                h += ( h << 3 ) + ord(c.lower())
174        return int(h % sys.maxint)
175
176    def fullcompare(self, other):
177        """Compare two names, returning a 3-tuple (relation, order, nlabels).
178
179        I{relation} describes the relation ship between the names,
180        and is one of: dns.name.NAMERELN_NONE,
181        dns.name.NAMERELN_SUPERDOMAIN, dns.name.NAMERELN_SUBDOMAIN,
182        dns.name.NAMERELN_EQUAL, or dns.name.NAMERELN_COMMONANCESTOR
183
184        I{order} is < 0 if self < other, > 0 if self > other, and ==
185        0 if self == other.  A relative name is always less than an
186        absolute name.  If both names have the same relativity, then
187        the DNSSEC order relation is used to order them.
188
189        I{nlabels} is the number of significant labels that the two names
190        have in common.
191        """
192
193        sabs = self.is_absolute()
194        oabs = other.is_absolute()
195        if sabs != oabs:
196            if sabs:
197                return (NAMERELN_NONE, 1, 0)
198            else:
199                return (NAMERELN_NONE, -1, 0)
200        l1 = len(self.labels)
201        l2 = len(other.labels)
202        ldiff = l1 - l2
203        if ldiff < 0:
204            l = l1
205        else:
206            l = l2
207
208        order = 0
209        nlabels = 0
210        namereln = NAMERELN_NONE
211        while l > 0:
212            l -= 1
213            l1 -= 1
214            l2 -= 1
215            label1 = self.labels[l1].lower()
216            label2 = other.labels[l2].lower()
217            if label1 < label2:
218                order = -1
219                if nlabels > 0:
220                    namereln = NAMERELN_COMMONANCESTOR
221                return (namereln, order, nlabels)
222            elif label1 > label2:
223                order = 1
224                if nlabels > 0:
225                    namereln = NAMERELN_COMMONANCESTOR
226                return (namereln, order, nlabels)
227            nlabels += 1
228        order = ldiff
229        if ldiff < 0:
230            namereln = NAMERELN_SUPERDOMAIN
231        elif ldiff > 0:
232            namereln = NAMERELN_SUBDOMAIN
233        else:
234            namereln = NAMERELN_EQUAL
235        return (namereln, order, nlabels)
236
237    def is_subdomain(self, other):
238        """Is self a subdomain of other?
239
240        The notion of subdomain includes equality.
241        @rtype: bool
242        """
243
244        (nr, o, nl) = self.fullcompare(other)
245        if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL:
246            return True
247        return False
248
249    def is_superdomain(self, other):
250        """Is self a superdomain of other?
251
252        The notion of subdomain includes equality.
253        @rtype: bool
254        """
255
256        (nr, o, nl) = self.fullcompare(other)
257        if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL:
258            return True
259        return False
260
261    def canonicalize(self):
262        """Return a name which is equal to the current name, but is in
263        DNSSEC canonical form.
264        @rtype: dns.name.Name object
265        """
266
267        return Name([x.lower() for x in self.labels])
268
269    def __eq__(self, other):
270        if isinstance(other, Name):
271            return self.fullcompare(other)[1] == 0
272        else:
273            return False
274
275    def __ne__(self, other):
276        if isinstance(other, Name):
277            return self.fullcompare(other)[1] != 0
278        else:
279            return True
280
281    def __lt__(self, other):
282        if isinstance(other, Name):
283            return self.fullcompare(other)[1] < 0
284        else:
285            return NotImplemented
286
287    def __le__(self, other):
288        if isinstance(other, Name):
289            return self.fullcompare(other)[1] <= 0
290        else:
291            return NotImplemented
292
293    def __ge__(self, other):
294        if isinstance(other, Name):
295            return self.fullcompare(other)[1] >= 0
296        else:
297            return NotImplemented
298
299    def __gt__(self, other):
300        if isinstance(other, Name):
301            return self.fullcompare(other)[1] > 0
302        else:
303            return NotImplemented
304
305    def __repr__(self):
306        return '<DNS name ' + self.__str__() + '>'
307
308    def __str__(self):
309        return self.to_text(False)
310
311    def to_text(self, omit_final_dot = False):
312        """Convert name to text format.
313        @param omit_final_dot: If True, don't emit the final dot (denoting the
314        root label) for absolute names.  The default is False.
315        @rtype: string
316        """
317
318        if len(self.labels) == 0:
319            return '@'
320        if len(self.labels) == 1 and self.labels[0] == '':
321            return '.'
322        if omit_final_dot and self.is_absolute():
323            l = self.labels[:-1]
324        else:
325            l = self.labels
326        s = '.'.join(map(_escapify, l))
327        return s
328
329    def to_unicode(self, omit_final_dot = False):
330        """Convert name to Unicode text format.
331
332        IDN ACE lables are converted to Unicode.
333
334        @param omit_final_dot: If True, don't emit the final dot (denoting the
335        root label) for absolute names.  The default is False.
336        @rtype: string
337        """
338
339        if len(self.labels) == 0:
340            return u'@'
341        if len(self.labels) == 1 and self.labels[0] == '':
342            return u'.'
343        if omit_final_dot and self.is_absolute():
344            l = self.labels[:-1]
345        else:
346            l = self.labels
347        s = u'.'.join([encodings.idna.ToUnicode(_escapify(x)) for x in l])
348        return s
349
350    def to_digestable(self, origin=None):
351        """Convert name to a format suitable for digesting in hashes.
352
353        The name is canonicalized and converted to uncompressed wire format.
354
355        @param origin: If the name is relative and origin is not None, then
356        origin will be appended to it.
357        @type origin: dns.name.Name object
358        @raises NeedAbsoluteNameOrOrigin: All names in wire format are
359        absolute.  If self is a relative name, then an origin must be supplied;
360        if it is missing, then this exception is raised
361        @rtype: string
362        """
363
364        if not self.is_absolute():
365            if origin is None or not origin.is_absolute():
366                raise NeedAbsoluteNameOrOrigin
367            labels = list(self.labels)
368            labels.extend(list(origin.labels))
369        else:
370            labels = self.labels
371        dlabels = ["%s%s" % (chr(len(x)), x.lower()) for x in labels]
372        return ''.join(dlabels)
373
374    def to_wire(self, file = None, compress = None, origin = None):
375        """Convert name to wire format, possibly compressing it.
376
377        @param file: the file where the name is emitted (typically
378        a cStringIO file).  If None, a string containing the wire name
379        will be returned.
380        @type file: file or None
381        @param compress: The compression table.  If None (the default) names
382        will not be compressed.
383        @type compress: dict
384        @param origin: If the name is relative and origin is not None, then
385        origin will be appended to it.
386        @type origin: dns.name.Name object
387        @raises NeedAbsoluteNameOrOrigin: All names in wire format are
388        absolute.  If self is a relative name, then an origin must be supplied;
389        if it is missing, then this exception is raised
390        """
391
392        if file is None:
393            file = cStringIO.StringIO()
394            want_return = True
395        else:
396            want_return = False
397
398        if not self.is_absolute():
399            if origin is None or not origin.is_absolute():
400                raise NeedAbsoluteNameOrOrigin
401            labels = list(self.labels)
402            labels.extend(list(origin.labels))
403        else:
404            labels = self.labels
405        i = 0
406        for label in labels:
407            n = Name(labels[i:])
408            i += 1
409            if not compress is None:
410                pos = compress.get(n)
411            else:
412                pos = None
413            if not pos is None:
414                value = 0xc000 + pos
415                s = struct.pack('!H', value)
416                file.write(s)
417                break
418            else:
419                if not compress is None and len(n) > 1:
420                    pos = file.tell()
421                    if pos < 0xc000:
422                        compress[n] = pos
423                l = len(label)
424                file.write(chr(l))
425                if l > 0:
426                    file.write(label)
427        if want_return:
428            return file.getvalue()
429
430    def __len__(self):
431        """The length of the name (in labels).
432        @rtype: int
433        """
434
435        return len(self.labels)
436
437    def __getitem__(self, index):
438        return self.labels[index]
439
440    def __getslice__(self, start, stop):
441        return self.labels[start:stop]
442
443    def __add__(self, other):
444        return self.concatenate(other)
445
446    def __sub__(self, other):
447        return self.relativize(other)
448
449    def split(self, depth):
450        """Split a name into a prefix and suffix at depth.
451
452        @param depth: the number of labels in the suffix
453        @type depth: int
454        @raises ValueError: the depth was not >= 0 and <= the length of the
455        name.
456        @returns: the tuple (prefix, suffix)
457        @rtype: tuple
458        """
459
460        l = len(self.labels)
461        if depth == 0:
462            return (self, dns.name.empty)
463        elif depth == l:
464            return (dns.name.empty, self)
465        elif depth < 0 or depth > l:
466            raise ValueError('depth must be >= 0 and <= the length of the name')
467        return (Name(self[: -depth]), Name(self[-depth :]))
468
469    def concatenate(self, other):
470        """Return a new name which is the concatenation of self and other.
471        @rtype: dns.name.Name object
472        @raises AbsoluteConcatenation: self is absolute and other is
473        not the empty name
474        """
475
476        if self.is_absolute() and len(other) > 0:
477            raise AbsoluteConcatenation
478        labels = list(self.labels)
479        labels.extend(list(other.labels))
480        return Name(labels)
481
482    def relativize(self, origin):
483        """If self is a subdomain of origin, return a new name which is self
484        relative to origin.  Otherwise return self.
485        @rtype: dns.name.Name object
486        """
487
488        if not origin is None and self.is_subdomain(origin):
489            return Name(self[: -len(origin)])
490        else:
491            return self
492
493    def derelativize(self, origin):
494        """If self is a relative name, return a new name which is the
495        concatenation of self and origin.  Otherwise return self.
496        @rtype: dns.name.Name object
497        """
498
499        if not self.is_absolute():
500            return self.concatenate(origin)
501        else:
502            return self
503
504    def choose_relativity(self, origin=None, relativize=True):
505        """Return a name with the relativity desired by the caller.  If
506        origin is None, then self is returned.  Otherwise, if
507        relativize is true the name is relativized, and if relativize is
508        false the name is derelativized.
509        @rtype: dns.name.Name object
510        """
511
512        if origin:
513            if relativize:
514                return self.relativize(origin)
515            else:
516                return self.derelativize(origin)
517        else:
518            return self
519
520    def parent(self):
521        """Return the parent of the name.
522        @rtype: dns.name.Name object
523        @raises NoParent: the name is either the root name or the empty name,
524        and thus has no parent.
525        """
526        if self == root or self == empty:
527            raise NoParent
528        return Name(self.labels[1:])
529
# The root name: exactly one label, the empty label.  Renders as '.'.
root = Name(('',))
# The empty name: no labels at all.  Renders as '@'.
empty = Name(())
532
def from_unicode(text, origin = root):
    """Convert unicode text into a Name object.

    Labels are encoded in IDN ACE form.  Labels may be separated by '.',
    U+3002 (ideographic full stop), U+FF0E (fullwidth full stop) or
    U+FF61 (halfwidth ideographic full stop).  A backslash escapes the
    following character; a backslash followed by three decimal digits
    (C{\\DDD}) denotes the character with that value, which must be at
    most 255.  The text '@' (or empty text) denotes the empty name and
    '.' denotes the root name.

    @param text: the text to convert
    @type text: unicode string
    @param origin: if not None, it is appended to the result when the
    result is relative (i.e. the text does not end with a separator)
    @type origin: dns.name.Name object or None
    @raises ValueError: text is not unicode, or origin is not a Name or None
    @raises EmptyLabel: an empty label appears inside the name
    @raises BadEscape: an escape is incomplete, has a non-digit after its
    first digit, or denotes a value greater than 255
    @rtype: dns.name.Name object
    """

    if not isinstance(text, unicode):
        raise ValueError("input to from_unicode() must be a unicode string")
    if not (origin is None or isinstance(origin, Name)):
        raise ValueError("origin must be a Name or None")
    labels = []
    label = u''
    escaping = False
    edigits = 0
    total = 0
    if text == u'@':
        text = u''
    if text:
        if text == u'.':
            return Name([''])	# no Unicode "u" on this constant!
        for c in text:
            if escaping:
                if edigits == 0:
                    if c.isdigit():
                        total = int(c)
                        edigits += 1
                    else:
                        # "\X" for a non-digit X is just a literal X.
                        label += c
                        escaping = False
                else:
                    if not c.isdigit():
                        raise BadEscape
                    total *= 10
                    total += int(c)
                    edigits += 1
                    if edigits == 3:
                        escaping = False
                        # chr() would raise a bare ValueError above 255,
                        # and a non-ASCII byte string cannot be appended
                        # to a unicode label; build a unicode character.
                        if total > 255:
                            raise BadEscape
                        label += unichr(total)
            elif c == u'.' or c == u'\u3002' or \
                 c == u'\uff0e' or c == u'\uff61':
                if len(label) == 0:
                    raise EmptyLabel
                labels.append(encodings.idna.ToASCII(label))
                label = u''
            elif c == u'\\':
                escaping = True
                edigits = 0
                total = 0
            else:
                label += c
        if escaping:
            raise BadEscape
        if len(label) > 0:
            labels.append(encodings.idna.ToASCII(label))
        else:
            # Text ended with a separator: the name is absolute.
            labels.append('')
    if (len(labels) == 0 or labels[-1] != '') and origin is not None:
        labels.extend(list(origin.labels))
    return Name(labels)
594
def from_text(text, origin = root):
    """Convert text into a Name object.

    Labels are separated by unescaped '.' characters.  A backslash escapes
    the following character; a backslash followed by three decimal digits
    (C{\\DDD}) denotes the octet with that value, which must be at most
    255.  The text '@' (or empty text) denotes the empty name and '.'
    denotes the root name.  Unicode input is passed to L{from_unicode}
    (Python 2.3 and later).

    @param text: the text to convert
    @type text: string
    @param origin: if not None, it is appended to the result when the
    result is relative (i.e. the text does not end with '.')
    @type origin: dns.name.Name object or None
    @raises ValueError: text is not a string, or origin is not a Name or None
    @raises EmptyLabel: an empty label appears inside the name
    @raises BadEscape: an escape is incomplete, has a non-digit after its
    first digit, or denotes a value greater than 255
    @rtype: dns.name.Name object
    """

    if not isinstance(text, str):
        if isinstance(text, unicode) and sys.hexversion >= 0x02030000:
            return from_unicode(text, origin)
        else:
            raise ValueError("input to from_text() must be a string")
    if not (origin is None or isinstance(origin, Name)):
        raise ValueError("origin must be a Name or None")
    labels = []
    label = ''
    escaping = False
    edigits = 0
    total = 0
    if text == '@':
        text = ''
    if text:
        if text == '.':
            return Name([''])
        for c in text:
            if escaping:
                if edigits == 0:
                    if c.isdigit():
                        total = int(c)
                        edigits += 1
                    else:
                        # "\X" for a non-digit X is just a literal X.
                        label += c
                        escaping = False
                else:
                    if not c.isdigit():
                        raise BadEscape
                    total *= 10
                    total += int(c)
                    edigits += 1
                    if edigits == 3:
                        escaping = False
                        # A \DDD escape names one octet; reject values
                        # chr() cannot represent instead of leaking a
                        # bare ValueError.
                        if total > 255:
                            raise BadEscape
                        label += chr(total)
            elif c == '.':
                if len(label) == 0:
                    raise EmptyLabel
                labels.append(label)
                label = ''
            elif c == '\\':
                escaping = True
                edigits = 0
                total = 0
            else:
                label += c
        if escaping:
            raise BadEscape
        if len(label) > 0:
            labels.append(label)
        else:
            # Text ended with '.': the name is absolute.
            labels.append('')
    if (len(labels) == 0 or labels[-1] != '') and origin is not None:
        labels.extend(list(origin.labels))
    return Name(labels)
655
def from_wire(message, current):
    """Convert possibly compressed wire format into a Name.
    @param message: the entire DNS message
    @type message: string
    @param current: the offset of the beginning of the name from the start
    of the message
    @type current: int
    @raises ValueError: message is not a byte string
    @raises dns.exception.FormError: the message ends before the name does
    @raises dns.name.BadPointer: a compression pointer did not point backwards
    in the message
    @raises dns.name.BadLabelType: an invalid label type was encountered.
    @returns: a tuple consisting of the name that was read and the number
    of bytes of the wire format message which were consumed reading it
    @rtype: (dns.name.Name object, int) tuple
    """

    if not isinstance(message, str):
        raise ValueError("input to from_wire() must be a byte string")
    end = len(message)
    labels = []
    # Every pointer must target an offset strictly before the previous one,
    # so following pointers always terminates.
    biggest_pointer = current
    # Number of pointers followed.  Only octets read before the first
    # pointer count toward the number of octets consumed (cused).
    hops = 0
    if current >= end:
        raise dns.exception.FormError
    count = ord(message[current])
    current += 1
    cused = 1
    while count != 0:
        if count < 64:
            # Ordinary label: 'count' octets of label data follow.  A short
            # slice would silently truncate the label, so check first.
            if current + count > end:
                raise dns.exception.FormError
            labels.append(message[current : current + count])
            current += count
            if hops == 0:
                cused += count
        elif count >= 192:
            # Compression pointer: the low 6 bits of this octet and all of
            # the next octet form a 14-bit offset into the message.
            if current >= end:
                raise dns.exception.FormError
            current = (count & 0x3f) * 256 + ord(message[current])
            if hops == 0:
                cused += 1
            if current >= biggest_pointer:
                raise BadPointer
            biggest_pointer = current
            hops += 1
        else:
            # Label types with top bits 01 or 10 (values 64-191) are
            # not supported.
            raise BadLabelType
        if current >= end:
            raise dns.exception.FormError
        count = ord(message[current])
        current += 1
        if hops == 0:
            cused += 1
    labels.append('')
    return (Name(labels), cused)
701