1# Author: Trevor Perrin
2# See the LICENSE file for legal information regarding use of this file.
3
4"""Classes for reading/writing binary data (such as TLS records)."""
5
6from .compat import *
7
class Writer(object):
    """Accumulate big-endian encoded unsigned integers in a byte buffer.

    The encoded output is exposed as the ``bytes`` attribute, a
    ``bytearray`` that grows with every ``add*`` call.
    """

    def __init__(self):
        self.bytes = bytearray(0)

    def add(self, x, length):
        """Append ``x`` encoded as an unsigned big-endian integer.

        :param x: value to encode; must satisfy ``0 <= x < 256**length``
        :param length: number of bytes used for the encoding
        :raises ValueError: if ``x`` cannot be represented in ``length``
            bytes; the buffer is left unchanged in that case
        """
        # Validate before touching the buffer. Previously an oversized value
        # had its high bytes silently dropped and a negative value was
        # written in two's complement, both producing corrupt output.
        if x < 0 or x >= 1 << (8 * length):
            raise ValueError("Can't represent value in specified length")
        self.bytes += bytearray(length)
        end = len(self.bytes)
        # Fill the newly appended bytes from the last one backwards,
        # least significant byte first.
        for index in range(end - 1, end - 1 - length, -1):
            self.bytes[index] = x & 0xFF
            x >>= 8

    def addFixSeq(self, seq, length):
        """Append every element of ``seq``, each encoded in ``length`` bytes.

        No length prefix is written.
        """
        for e in seq:
            self.add(e, length)

    def addVarSeq(self, seq, length, lengthLength):
        """Append ``seq`` preceded by its total encoded size.

        The prefix is ``len(seq) * length`` (a byte count, not an element
        count), encoded in ``lengthLength`` bytes. Each element is then
        encoded in ``length`` bytes.

        :raises ValueError: if the prefix or any element does not fit in
            its field
        """
        self.add(len(seq) * length, lengthLength)
        for e in seq:
            self.add(e, length)
28
class Parser(object):
    """Sequentially decode big-endian integers and byte strings from a buffer.

    Every read advances ``self.index``. A read that would run past the end
    of the buffer raises ``SyntaxError``.
    """

    def __init__(self, bytes):
        # The parameter name shadows the builtin; it is kept unchanged so
        # callers passing it by keyword keep working.
        self.bytes = bytes
        self.index = 0

    def get(self, length):
        """Read ``length`` bytes as an unsigned big-endian integer.

        :raises SyntaxError: if fewer than ``length`` bytes remain
        """
        if self.index + length > len(self.bytes):
            raise SyntaxError()
        x = 0
        for count in range(length):
            x <<= 8
            x |= self.bytes[self.index]
            self.index += 1
        return x

    def getFixBytes(self, lengthBytes):
        """Return the next ``lengthBytes`` bytes as a slice of the buffer.

        :raises SyntaxError: if fewer than ``lengthBytes`` bytes remain; the
            read position is not advanced in that case
        """
        end = self.index + lengthBytes
        # Without this check a truncated buffer would yield a short slice and
        # move the index past the end, unlike every other read method.
        if end > len(self.bytes):
            raise SyntaxError()
        chunk = self.bytes[self.index:end]
        self.index = end
        return chunk

    def getVarBytes(self, lengthLength):
        """Read a ``lengthLength``-byte size prefix, then that many bytes."""
        lengthBytes = self.get(lengthLength)
        return self.getFixBytes(lengthBytes)

    def getFixList(self, length, lengthList):
        """Read ``lengthList`` integers, each encoded in ``length`` bytes."""
        return [self.get(length) for _ in range(lengthList)]

    def getVarList(self, length, lengthLength):
        """Read a byte-count prefix, then a list of ``length``-byte integers.

        The prefix (``lengthLength`` bytes) gives the total size in bytes.
        It must be a multiple of ``length``, otherwise ``SyntaxError`` is
        raised.
        """
        lengthBytes = self.get(lengthLength)
        if lengthBytes % length != 0:
            raise SyntaxError()
        return self.getFixList(length, lengthBytes // length)

    def startLengthCheck(self, lengthLength):
        """Read a ``lengthLength``-byte length and start tracking consumption.

        Bytes read after the prefix are counted against that length.
        """
        self.lengthCheck = self.get(lengthLength)
        self.indexCheck = self.index

    def setLengthCheck(self, length):
        """Start tracking consumption against an explicitly given ``length``."""
        self.lengthCheck = length
        self.indexCheck = self.index

    def stopLengthCheck(self):
        """Raise ``SyntaxError`` unless exactly the checked length was read.

        The count runs from the last ``startLengthCheck``/``setLengthCheck``.
        """
        if (self.index - self.indexCheck) != self.lengthCheck:
            raise SyntaxError()

    def atLengthCheck(self):
        """Report whether the checked length has been fully consumed.

        Returns False while bytes remain and True when exactly consumed.
        Raises ``SyntaxError`` if reading has overrun the checked length.
        """
        consumed = self.index - self.indexCheck
        if consumed > self.lengthCheck:
            raise SyntaxError()
        return consumed == self.lengthCheck
88