1## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
3## Copyright (C) Philippe Biondi <phil@secdev.org>
4## Modified by Maxence Tury <maxence.tury@ssi.gouv.fr>
5## Acknowledgment: Ralph Broenink
6## This program is published under a GPLv2 license
7
8"""
9Basic Encoding Rules (BER) for ASN.1
10"""
11
from __future__ import absolute_import
from scapy.error import warning
from scapy.compat import *
from scapy.utils import binrepr,inet_aton,inet_ntoa
from scapy.asn1.asn1 import ASN1_Decoding_Error,ASN1_Encoding_Error,ASN1_BadTag_Decoding_Error,ASN1_Codecs,ASN1_Class_UNIVERSAL,ASN1_Error,ASN1_DECODING_ERROR,ASN1_BADTAG,ASN1_Object
import scapy.modules.six as six
18
19##################
20## BER encoding ##
21##################
22
23
24
25#####[ BER tools ]#####
26
27
class BER_Exception(Exception):
    """Generic BER tooling error (e.g. a length too large to be encoded)."""
30
class BER_Encoding_Error(ASN1_Encoding_Error):
    """Raised when a value cannot be BER-encoded.

    ``encoded`` and ``remaining`` optionally describe the progress made
    before the failure; both are shown by ``__str__``.
    """
    def __init__(self, msg, encoded=None, remaining=None):
        Exception.__init__(self, msg)
        self.remaining = remaining
        self.encoded = encoded

    def __str__(self):
        s = Exception.__str__(self)
        # ASN.1 objects get their multi-line strshow() dump. Anything else
        # (None, lists, raw bytes...) is shown with repr(). Codec classes
        # (BERcodec_*) only expose classmethods and have no strshow().
        if isinstance(self.encoded, ASN1_Object):
            s += "\n### Already encoded ###\n%s" % self.encoded.strshow()
        else:
            s += "\n### Already encoded ###\n%r" % self.encoded
        s += "\n### Remaining ###\n%r" % self.remaining
        return s
44
class BER_Decoding_Error(ASN1_Decoding_Error):
    """Raised when BER input cannot be decoded.

    ``decoded`` holds what was decoded before the failure (an ASN.1 object,
    a list of them as built by BERcodec_SEQUENCE, or None). ``remaining``
    holds the input that was not consumed.
    """
    def __init__(self, msg, decoded=None, remaining=None):
        Exception.__init__(self, msg)
        self.remaining = remaining
        self.decoded = decoded

    def __str__(self):
        s = Exception.__str__(self)
        # ASN.1 objects get their multi-line strshow() dump. Anything else
        # (None, lists, ...) is shown with repr(). Codec classes (BERcodec_*)
        # only expose classmethods and have no strshow().
        if isinstance(self.decoded, ASN1_Object):
            s += "\n### Already decoded ###\n%s" % self.decoded.strshow()
        else:
            s += "\n### Already decoded ###\n%r" % self.decoded
        s += "\n### Remaining ###\n%r" % self.remaining
        return s
58
class BER_BadTag_Decoding_Error(BER_Decoding_Error, ASN1_BadTag_Decoding_Error):
    """Decoding error raised when the identifier octets do not carry the
    expected tag (see BERcodec_Object.check_type)."""
61
def BER_len_enc(l, size=0):
    """Encode the length ``l`` using BER definite-length octets.

    When ``l`` <= 127 and ``size`` is 0, the short form (one octet) is used.
    Otherwise the long form is emitted: an octet 0x80|n followed by n
    big-endian length octets, with n at least ``size``.

    Raises BER_Exception if more than 127 length octets would be needed.
    """
    if l <= 127 and size == 0:
        return chb(l)
    chunks = []  # least-significant octet first
    while l or size > 0:
        chunks.append(chb(l & 0xff))
        l >>= 8
        size -= 1
    chunks.reverse()
    s = b"".join(chunks)
    if len(s) > 127:
        raise BER_Exception("BER_len_enc: Length too long (%i) to be encoded [%r]" % (len(s),s))
    return chb(len(s) | 0x80) + s
def BER_len_dec(s):
    """Decode the BER length octets at the start of ``s``.

    Returns ``(length, rest)``, where ``rest`` is ``s`` without the length
    octets. In the short form, an octet below 0x80 is the length itself.
    In the long form, an octet 0x80|n is followed by n big-endian length
    octets. A lone 0x80 octet (the BER indefinite form) is not handled
    specially and decodes as a length of 0.

    Raises BER_Decoding_Error if ``s`` is empty or shorter than the long
    form announces.
    """
    if not s:
        # Without this guard s[0] would raise IndexError, which safe-mode
        # decoding (that only traps ASN.1/BER errors) would not catch.
        raise BER_Decoding_Error("BER_len_dec: got empty string", remaining=s)
    l = orb(s[0])
    if not l & 0x80:
        return l, s[1:]
    l &= 0x7f
    if len(s) <= l:
        raise BER_Decoding_Error("BER_len_dec: Got %i bytes while expecting %i" % (len(s)-1, l),remaining=s)
    ll = 0
    for c in s[1:l+1]:
        ll <<= 8
        ll |= orb(c)
    return ll, s[l+1:]
85
def BER_num_enc(l, size=1):
    """Encode the non-negative int ``l`` as big-endian base-128 digits.

    Every octet except the last has bit 8 set ("more octets follow").
    At least ``size`` octets are produced.
    """
    octets = []  # least-significant digit first
    while l or size > 0:
        # the first digit produced is the final octet: no continuation bit
        octets.append((l & 0x7f) | (0x80 if octets else 0))
        l >>= 7
        size -= 1
    return b"".join(chb(o) for o in reversed(octets))
def BER_num_dec(s, cls_id=0):
    """Decode a base-128 number (bit 8 set = "more octets follow") from ``s``.

    ``cls_id`` seeds the accumulator, so it becomes the most significant
    base-128 digit of the result (BER_id_dec uses it for the class bits).
    Returns ``(value, rest)``.

    Raises BER_Decoding_Error on empty input, or when the last available
    octet still has its continuation bit set.
    """
    if not s:
        raise BER_Decoding_Error("BER_num_dec: got empty string", remaining=s)
    value = cls_id
    for pos in range(len(s)):
        octet = orb(s[pos])
        value = (value << 7) | (octet & 0x7f)
        if not (octet & 0x80):
            return value, s[pos + 1:]
    raise BER_Decoding_Error("BER_num_dec: unfinished number description", remaining=s)
108
def BER_id_dec(s):
    """Decode the identifier octet(s) at the start of ``s``.

    Returns ``(tag, rest)``. Raises BER_Decoding_Error if ``s`` is empty.
    A truncated high-tag-number identifier raises BER_Decoding_Error
    through BER_num_dec.
    """
    # This returns the tag ALONG WITH THE PADDED CLASS+CONSTRUCTIVE INFO.
    # Let's recall that bits 8-7 from the first byte of the tag encode
    # the class information, while bit 6 means primitive or constructive.
    #
    # For instance, with low-tag-number b'\x81', class would be 0b10
    # ('context-specific') and tag 0x01, but we return 0x81 as a whole.
    # For b'\xff\x22', class would be 0b11 ('private'), constructed, then
    # padding, then tag 0x22, but we return (0xff>>5)*128^1 + 0x22*128^0.
    # Why the 5-bit-shifting? Because it provides an unequivocal encoding
    # on base 128 (note that 0xff would equal 1*128^1 + 127*128^0...),
    # as we know that bits 5 to 1 are fixed to 1 anyway.
    #
    # As long as there is no class differentiation, we have to keep this info
    # encoded in scapy's tag in order to reuse it for packet building.
    # Note that tags thus may have to be hard-coded with their extended
    # information, e.g. a SEQUENCE from asn1.py has a direct tag 0x20|16.
    if not s:
        # Report an ASN.1 decoding error rather than an IndexError from s[0].
        raise BER_Decoding_Error("BER_id_dec: got empty string", remaining=s)
    x = orb(s[0])
    if x & 0x1f != 0x1f:
        # low-tag-number
        return x, s[1:]
    else:
        # high-tag-number
        return BER_num_dec(s[1:], cls_id=x>>5)
def BER_id_enc(n):
    """Encode a scapy tag value ``n`` back into BER identifier octets.

    Values below 256 are emitted as a single identifier octet. Larger values
    are read as the packing produced by BER_id_dec's high-tag-number branch:
    the top base-128 digit holds the (first_octet >> 5) class bits.
    """
    if n >= 256:
        # high-tag-number form
        encoded = BER_num_enc(n)
        # Keep only the 3 class/constructed bits of the top digit, move them
        # back to bits 8-6 and set bits 5-1 to the 0x1f high-tag marker.
        first = ((orb(encoded[0]) & 0x07) << 5) | 0x1f
        return chb(first) + encoded[1:]
    # low-tag-number form
    return chb(n)
145
146# The functions below provide implicit and explicit tagging support.
def BER_tagging_dec(s, hidden_tag=None, implicit_tag=None,
                    explicit_tag=None, safe=False):
    """Strip an implicit or explicit tag from the start of ``s``.

    Implicit: the observed identifier is swapped for one octet built from
    ``hash(hidden_tag)`` (presumably the underlying type's tag), so the inner
    codec sees the tag it expects. Explicit: the outer identifier and its
    length octets are removed; the length value itself is not checked.

    Returns ``(real_tag, s)``. ``real_tag`` is the observed identifier when it
    differed from the expected one in safe mode, otherwise None. Outside
    safe mode a mismatch raises BER_Decoding_Error.
    """
    real_tag = None
    if len(s) == 0:
        return real_tag, s
    err_msg = "BER_tagging_dec: observed tag does not match expected tag"
    if implicit_tag is not None:
        ber_id, s = BER_id_dec(s)
        if ber_id != implicit_tag:
            if not safe:
                raise BER_Decoding_Error(err_msg, remaining=s)
            real_tag = ber_id
        s = chb(hash(hidden_tag)) + s
    elif explicit_tag is not None:
        ber_id, s = BER_id_dec(s)
        if ber_id != explicit_tag:
            if not safe:
                raise BER_Decoding_Error(err_msg, remaining=s)
            real_tag = ber_id
        _, s = BER_len_dec(s)
    return real_tag, s
def BER_tagging_enc(s, implicit_tag=None, explicit_tag=None):
    """Apply implicit or explicit tagging to the already-encoded TLV ``s``.

    Implicit replaces the first octet of ``s`` with the new identifier (this
    assumes the original identifier was a single octet). Explicit wraps
    ``s`` in a new identifier plus length. Empty ``s`` is returned unchanged.
    """
    if len(s) == 0:
        return s
    if implicit_tag is not None:
        return BER_id_enc(implicit_tag) + s[1:]
    if explicit_tag is not None:
        return BER_id_enc(explicit_tag) + BER_len_enc(len(s)) + s
    return s
177
178#####[ BER classes ]#####
179
class BERcodec_metaclass(type):
    """Metaclass registering every new codec class with its ASN.1 tag.

    Each created class ``c`` is registered via ``c.tag.register(c.codec, c)``.
    A registration failure only emits a warning; the class is still returned.
    """
    def __new__(cls, name, bases, dct):
        c = super(BERcodec_metaclass, cls).__new__(cls, name, bases, dct)
        try:
            c.tag.register(c.codec, c)
        except Exception:
            # Best-effort registration: do not swallow KeyboardInterrupt or
            # SystemExit. Use getattr because a missing attribute may be
            # exactly what made registration fail.
            warning("Error registering %r for %r" % (getattr(c, "tag", None),
                                                     getattr(c, "codec", None)))
        return c
188
189
class BERcodec_Object(six.with_metaclass(BERcodec_metaclass)):
    """Base class of every BER codec.

    All methods are classmethods. ``tag`` is the ASN.1 tag a subclass
    handles (registered by BERcodec_metaclass). ``enc`` returns bytes, and
    ``dec``/``do_dec`` return an ``(asn1_object, remaining)`` pair.
    """
    codec = ASN1_Codecs.BER
    tag = ASN1_Class_UNIVERSAL.ANY

    @classmethod
    def asn1_object(cls, val):
        """Wrap ``val`` in the ASN.1 object type bound to ``cls.tag``."""
        return cls.tag.asn1_object(val)

    @classmethod
    def check_string(cls, s):
        """Raise BER_Decoding_Error if ``s`` is empty."""
        if not s:
            raise BER_Decoding_Error("%s: Got empty object while expecting tag %r" %
                                     (cls.__name__,cls.tag), remaining=s)

    @classmethod
    def check_type(cls, s):
        """Check that ``s`` starts with ``cls.tag``; return what follows the
        identifier octets.

        On mismatch, raises BER_BadTag_Decoding_Error whose ``remaining`` is
        the whole of ``s`` (so safe-mode ``dec`` can re-decode it generically).
        """
        cls.check_string(s)
        tag, remainder = BER_id_dec(s)
        if cls.tag != tag:
            raise BER_BadTag_Decoding_Error("%s: Got tag [%i/%#x] while expecting %r" %
                                            (cls.__name__, tag, tag, cls.tag), remaining=s)
        return remainder

    @classmethod
    def check_type_get_len(cls, s):
        """Check the tag, then decode the length; return ``(length, rest)``.

        ``rest`` may be shorter than ``length``.
        """
        s2 = cls.check_type(s)
        if not s2:
            raise BER_Decoding_Error("%s: No bytes while expecting a length" %
                                     cls.__name__, remaining=s)
        return BER_len_dec(s2)

    @classmethod
    def check_type_check_len(cls, s):
        """Check tag and length; return ``(length, content, remaining)``.

        Raises BER_Decoding_Error if fewer content bytes than announced exist.
        """
        l,s3 = cls.check_type_get_len(s)
        if len(s3) < l:
            raise BER_Decoding_Error("%s: Got %i bytes while expecting %i" %
                                     (cls.__name__, len(s3), l), remaining=s)
        return l,s3[:l],s3[l:]

    @classmethod
    def do_dec(cls, s, context=None, safe=False):
        """Dispatch to the BER codec registered in ``context`` (default:
        ``cls.tag.context``) for the identifier found at the start of ``s``."""
        if context is None:
            context = cls.tag.context
        cls.check_string(s)
        p,_ = BER_id_dec(s)
        if p not in context:
            t = s
            if len(t) > 18:
                t = t[:15]+b"..."
            raise BER_Decoding_Error("Unknown prefix [%02x] for [%r]" % (p,t), remaining=s)
        codec = context[p].get_codec(ASN1_Codecs.BER)
        return codec.dec(s,context,safe)

    @classmethod
    def dec(cls, s, context=None, safe=False):
        """Decode ``s``; return ``(asn1_object, remaining)``.

        With ``safe`` set, errors become objects instead of exceptions:
        - a bad tag yields ASN1_BADTAG wrapping a generic re-decode;
        - other decoding errors yield ASN1_DECODING_ERROR and an empty
          (bytes) remainder.
        """
        if not safe:
            return cls.do_dec(s, context, safe)
        try:
            return cls.do_dec(s, context, safe)
        except BER_BadTag_Decoding_Error as e:
            o,remain = BERcodec_Object.dec(e.remaining, context, safe)
            return ASN1_BADTAG(o),remain
        except BER_Decoding_Error as e:
            # b"" (not ""): the remainder is bytes everywhere else.
            return ASN1_DECODING_ERROR(s, exc=e),b""
        except ASN1_Error as e:
            return ASN1_DECODING_ERROR(s, exc=e),b""

    @classmethod
    def safedec(cls, s, context=None):
        """Shorthand for ``dec(s, context, safe=True)``."""
        return cls.dec(s, context, safe=True)


    @classmethod
    def enc(cls, s):
        """Encode text or bytes as STRING, and anything else as INTEGER
        (through ``int()``)."""
        # bytes is not in six.string_types on Python 3; without it raw bytes
        # would be fed to int() and fail.
        if isinstance(s, six.string_types + (bytes,)):
            return BERcodec_STRING.enc(s)
        else:
            return BERcodec_INTEGER.enc(int(s))

ASN1_Codecs.BER.register_stem(BERcodec_Object)
267
268
269##########################
270#### BERcodec objects ####
271##########################
272
class BERcodec_INTEGER(BERcodec_Object):
    """BER codec for INTEGER: big-endian two's-complement content octets."""
    tag = ASN1_Class_UNIVERSAL.INTEGER

    @classmethod
    def enc(cls, i):
        """Encode the int ``i`` as a TLV with the minimal number of content
        octets (X.690 8.3.2)."""
        octets = []  # least-significant octet first
        while True:
            octets.append(i & 0xff)
            # Any value in [-128, -1] fits in the octet just emitted: its bit 8
            # is set, which marks the number as negative. (-128 is 0x80.)
            if -128 <= i < 0:
                break
            # A positive value whose top octet has bit 8 set needs a leading
            # 0x00 octet so it is not read back as negative.
            if 128 <= i <= 255:
                octets.append(0)
            i >>= 8
            if not i:
                break
        octets.reverse()
        content = b"".join(chb(o) for o in octets)
        return chb(hash(cls.tag)) + BER_len_enc(len(content)) + content

    @classmethod
    def do_dec(cls, s, context=None, safe=False):
        """Decode a two's-complement INTEGER; return ``(asn1_object, rest)``.

        Empty content decodes as 0.
        """
        l,s,t = cls.check_type_check_len(s)
        x = 0
        if s:
            if orb(s[0])&0x80: # negative int
                x = -1
            for c in s:
                x <<= 8
                x |= orb(c)
        return cls.asn1_object(x),t

class BERcodec_BOOLEAN(BERcodec_INTEGER):
    """BOOLEAN: encoded and decoded exactly like INTEGER."""
    tag = ASN1_Class_UNIVERSAL.BOOLEAN
306
class BERcodec_BIT_STRING(BERcodec_Object):
    """BER codec for BIT STRING. Values are strings of '0'/'1' characters."""
    tag = ASN1_Class_UNIVERSAL.BIT_STRING

    @classmethod
    def do_dec(cls, s, context=None, safe=False):
        """Decode to a '0'/'1' string with the padding bits removed.

        The unused-bits count itself is not kept. It is only validated
        (must be <= 7) in safe mode.
        """
        _, content, rest = cls.check_type_check_len(s)
        if not content:
            raise BER_Decoding_Error("BERcodec_BIT_STRING found no content (not even unused_bits byte)", remaining=content)
        pad = orb(content[0])
        if safe and pad > 7:
            raise BER_Decoding_Error("BERcodec_BIT_STRING: too many unused_bits advertised", remaining=content)
        bits = "".join(binrepr(orb(octet)).zfill(8) for octet in content[1:])
        if pad:
            bits = bits[:-pad]
        return cls.tag.asn1_object(bits), rest

    @classmethod
    def enc(cls, s):
        """Encode a '0'/'1' string, right-padding it with zero bits only
        (the DER rule)."""
        bits = raw(s)
        pad = -len(bits) % 8  # zero bits needed to reach a whole octet
        bits += b"0" * pad
        packed = b"".join(chb(int(bits[k:k + 8], 2)) for k in range(0, len(bits), 8))
        body = chb(pad) + packed
        return chb(hash(cls.tag)) + BER_len_enc(len(body)) + body
335
class BERcodec_STRING(BERcodec_Object):
    """BER codec for string-like types. Content octets are copied verbatim."""
    tag = ASN1_Class_UNIVERSAL.STRING

    @classmethod
    def enc(cls, s):
        payload = raw(s)  # make sure bytes are what gets encoded
        return b"".join([chb(hash(cls.tag)), BER_len_enc(len(payload)), payload])

    @classmethod
    def do_dec(cls, s, context=None, safe=False):
        _, content, rest = cls.check_type_check_len(s)
        return cls.tag.asn1_object(content), rest
346
class BERcodec_NULL(BERcodec_INTEGER):
    """BER codec for NULL.

    Decoding is inherited from INTEGER. A zero value encodes as the NULL tag
    followed by a zero length octet. Any other value falls back to INTEGER
    encoding under the NULL tag.
    """
    tag = ASN1_Class_UNIVERSAL.NULL

    @classmethod
    def enc(cls, i):
        if i == 0:
            return chb(hash(cls.tag))+b"\0"
        else:
            # Name the class explicitly: super(cls, cls) would resolve back
            # to this method for any subclass and recurse forever.
            return super(BERcodec_NULL, cls).enc(i)
355
class BERcodec_OID(BERcodec_Object):
    """BER codec for OBJECT IDENTIFIER; values are dotted byte strings."""
    tag = ASN1_Class_UNIVERSAL.OID

    @classmethod
    def enc(cls, oid):
        """Encode e.g. b"1.2.840": the first two arcs are merged into one
        subidentifier (40*X + Y), and every subidentifier is base-128."""
        oid = raw(oid)
        lst = [int(x) for x in oid.strip(b".").split(b".")]
        if len(lst) >= 2:
            lst[1] += 40*lst[0]
            del(lst[0])
        s = b"".join(BER_num_enc(k) for k in lst)
        return chb(hash(cls.tag))+BER_len_enc(len(s))+s

    @classmethod
    def do_dec(cls, s, context=None, safe=False):
        """Decode to a dotted byte string; return ``(asn1_object, rest)``."""
        l,s,t = cls.check_type_check_len(s)
        lst = []
        while s:
            l,s = BER_num_dec(s)
            lst.append(l)
        if lst:
            first = lst[0]
            # X.690 8.19.4: the first subidentifier is 40*X + Y with X in
            # {0, 1, 2}. Only X = 2 allows Y > 39, so anything >= 80 is arc 2
            # (e.g. 2.999 is encoded as 1079).
            if first < 80:
                arcs = list(divmod(first, 40))
            else:
                arcs = [2, first - 80]
            lst[0:1] = arcs
        return cls.asn1_object(b".".join(str(k).encode('ascii') for k in lst)), t
378
class BERcodec_ENUMERATED(BERcodec_INTEGER):
    """ENUMERATED: encoded and decoded exactly like INTEGER."""
    tag = ASN1_Class_UNIVERSAL.ENUMERATED

class BERcodec_UTF8_STRING(BERcodec_STRING):
    """UTF8String: raw content octets, handled like STRING."""
    tag = ASN1_Class_UNIVERSAL.UTF8_STRING

class BERcodec_NUMERIC_STRING(BERcodec_STRING):
    """NumericString: raw content octets, handled like STRING."""
    tag = ASN1_Class_UNIVERSAL.NUMERIC_STRING

class BERcodec_PRINTABLE_STRING(BERcodec_STRING):
    """PrintableString: raw content octets, handled like STRING."""
    tag = ASN1_Class_UNIVERSAL.PRINTABLE_STRING

class BERcodec_T61_STRING(BERcodec_STRING):
    """T61String: raw content octets, handled like STRING."""
    tag = ASN1_Class_UNIVERSAL.T61_STRING

class BERcodec_VIDEOTEX_STRING(BERcodec_STRING):
    """VideotexString: raw content octets, handled like STRING."""
    tag = ASN1_Class_UNIVERSAL.VIDEOTEX_STRING

class BERcodec_IA5_STRING(BERcodec_STRING):
    """IA5String: raw content octets, handled like STRING."""
    tag = ASN1_Class_UNIVERSAL.IA5_STRING

class BERcodec_UTC_TIME(BERcodec_STRING):
    """UTCTime: kept as its raw string form (no time parsing here)."""
    tag = ASN1_Class_UNIVERSAL.UTC_TIME

class BERcodec_GENERALIZED_TIME(BERcodec_STRING):
    """GeneralizedTime: kept as its raw string form (no time parsing here)."""
    tag = ASN1_Class_UNIVERSAL.GENERALIZED_TIME

class BERcodec_ISO646_STRING(BERcodec_STRING):
    """ISO646String: raw content octets, handled like STRING."""
    tag = ASN1_Class_UNIVERSAL.ISO646_STRING

class BERcodec_UNIVERSAL_STRING(BERcodec_STRING):
    """UniversalString: raw content octets, handled like STRING."""
    tag = ASN1_Class_UNIVERSAL.UNIVERSAL_STRING

class BERcodec_BMP_STRING(BERcodec_STRING):
    """BMPString: raw content octets, handled like STRING."""
    tag = ASN1_Class_UNIVERSAL.BMP_STRING
414
class BERcodec_SEQUENCE(BERcodec_Object):
    """BER codec for SEQUENCE: the content is the concatenation of the
    encoded elements."""
    tag = ASN1_Class_UNIVERSAL.SEQUENCE

    @classmethod
    def enc(cls, l):
        """Encode ``l``. Pre-encoded bytes are used as-is; otherwise each
        element's ``enc(codec)`` output is concatenated."""
        if not isinstance(l, bytes):
            l = b"".join(x.enc(cls.codec) for x in l)
        return chb(hash(cls.tag))+BER_len_enc(len(l))+l

    @classmethod
    def do_dec(cls, s, context=None, safe=False):
        """Decode every element; return ``(asn1_object, rest)``.

        Raises BER_Decoding_Error in two cases:
        - an element error propagates (in safe mode ``dec`` converts those
          into objects instead);
        - fewer content bytes are available than the length announces.
        The raised error's ``decoded`` holds the elements decoded so far.
        """
        if context is None:
            context = cls.tag.context
        l,st = cls.check_type_get_len(s) # we may have len(st) < l
        s,t = st[:l],st[l:]
        obj = []
        while s:
            try:
                o,s = BERcodec_Object.dec(s, context, safe)
            except BER_Decoding_Error as err:
                # Re-attach the bytes that follow this sequence. A nested
                # error may carry remaining=None (e.g. the "Not enough bytes"
                # error below raised by an inner sequence); None += bytes
                # would otherwise turn it into a TypeError.
                if err.remaining is None:
                    err.remaining = t
                else:
                    err.remaining += t
                if err.decoded is not None:
                    obj.append(err.decoded)
                err.decoded = obj
                raise
            obj.append(o)
        if len(st) < l:
            raise BER_Decoding_Error("Not enough bytes to decode sequence", decoded=obj)
        return cls.asn1_object(obj),t

class BERcodec_SET(BERcodec_SEQUENCE):
    """SET: encoded and decoded exactly like SEQUENCE."""
    tag = ASN1_Class_UNIVERSAL.SET
445
class BERcodec_IPADDRESS(BERcodec_STRING):
    """BER codec for IPADDRESS: packed IPv4 octets <-> dotted text, via
    inet_aton / inet_ntoa."""
    tag = ASN1_Class_UNIVERSAL.IPADDRESS

    @classmethod
    def enc(cls, ipaddr_ascii):
        try:
            packed = inet_aton(ipaddr_ascii)
        except Exception:
            raise BER_Encoding_Error("IPv4 address could not be encoded")
        header = chb(hash(cls.tag)) + BER_len_enc(len(packed))
        return header + packed

    @classmethod
    def do_dec(cls, s, context=None, safe=False):
        _, content, rest = cls.check_type_check_len(s)
        try:
            dotted = inet_ntoa(content)
        except Exception:
            raise BER_Decoding_Error("IP address could not be decoded", remaining=content)
        return cls.asn1_object(dotted), rest
463
class BERcodec_COUNTER32(BERcodec_INTEGER):
    """Counter32: encoded and decoded exactly like INTEGER."""
    tag = ASN1_Class_UNIVERSAL.COUNTER32

class BERcodec_GAUGE32(BERcodec_INTEGER):
    """Gauge32: encoded and decoded exactly like INTEGER."""
    tag = ASN1_Class_UNIVERSAL.GAUGE32

class BERcodec_TIME_TICKS(BERcodec_INTEGER):
    """TimeTicks: encoded and decoded exactly like INTEGER."""
    tag = ASN1_Class_UNIVERSAL.TIME_TICKS
472