1# Copyright 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A simple module for declaring C-like structures.
16
17Example usage:
18
19>>> # Declare a struct type by specifying name, field formats and field names.
20... # Field formats are the same as those used in the struct module.
21... import cstruct
22>>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
23>>>
24>>>
25>>> # Create instances from tuples or raw bytes. Data past the end is ignored.
26... n1 = NLMsgHdr((44, 32, 0x2, 0, 491))
27>>> print n1
28NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491)
29>>>
30>>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00"
31...               "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end")
32>>> print n2
33NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510)
34>>>
35>>> # Serialize to raw bytes.
36... print n1.Pack().encode("hex")
372c0000002000020000000000eb010000
38>>>
39>>> # Parse the beginning of a byte stream as a struct, and return the struct
40... # and the remainder of the stream for further reading.
41... data = ("\x2c\x00\x00\x00\x21\x00\x02\x00"
42...         "\x00\x00\x00\x00\xfe\x01\x00\x00"
43...         "more data")
44>>> cstruct.Read(data, NLMsgHdr)
45(NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510), 'more data')
46>>>
47"""
48
49import ctypes
50import struct
51
52
def Struct(name, fmt, fields):
  """Returns a new class representing a C-like structure.

  Args:
    name: Name of the structure. Used as the name of the returned class, in
      error messages and in the string representation of instances.
    fmt: Format string describing the binary layout, in the syntax used by the
      struct module.
    fields: Field names, either as a sequence of strings or as a single
      whitespace-separated string.

  Returns:
    A class whose instances hold one value per field. Instances are built from
    a sequence of values or from raw data, expose the fields as readable and
    writable attributes, and are serialized with Pack(). Calling len() on the
    class or on an instance gives the packed size in bytes.
  """
  if isinstance(fields, str):
    # split() rather than split(" "), so that repeated, leading or trailing
    # whitespace does not create fields with empty names.
    fields = fields.split()

  class Meta(type):
    """Metaclass that makes len() work on the struct class itself."""

    def __len__(cls):
      return cls._length

  # Inheriting from a base created by calling Meta directly applies the
  # metaclass on both Python 2 and Python 3. A __metaclass__ attribute is
  # silently ignored by Python 3, which would break len() on the class.
  class CStruct(Meta("CStructBase", (object,), {})):
    """Class representing a C-like structure."""

    _name = name
    _format = fmt
    _fields = fields
    _length = struct.calcsize(fmt)

    def _SetValues(self, values):
      # Go through object.__setattr__ because our own __setattr__ only accepts
      # field names.
      super(CStruct, self).__setattr__("_values", list(values))

    def _Parse(self, data):
      # Anything past the end of the structure is ignored.
      data = data[:self._length]
      self._SetValues(struct.unpack(self._format, data))

    def __init__(self, values):
      """Initializes the struct from raw data or from a sequence of values.

      Args:
        values: Either raw data at least as long as the structure (extra data
          at the end is ignored), or a sequence with one value per field.

      Raises:
        TypeError: The raw data is too short, or the sequence does not have
          exactly one value per field.
      """
      # On Python 2 bytes and str are the same type. On Python 3 this lets
      # bytes be parsed, while a text string still reaches struct.unpack and
      # is rejected there with a TypeError, as before.
      if isinstance(values, (bytes, str)):
        if len(values) < self._length:
          raise TypeError("%s requires string of length %d, got %d" %
                          (self._name, self._length, len(values)))
        self._Parse(values)
      else:
        # Initializing from a tuple.
        if len(values) != len(self._fields):
          raise TypeError("%s has exactly %d fields (%d given)" %
                          (self._name, len(self._fields), len(values)))
        self._SetValues(values)

    def _FieldIndex(self, attr):
      """Returns the position of field attr, or raises AttributeError."""
      try:
        return self._fields.index(attr)
      except ValueError:
        raise AttributeError("'%s' has no attribute '%s'" %
                             (self._name, attr))

    def __getattr__(self, name):
      # Only invoked when normal lookup fails, so methods and the class-level
      # _name/_format/_fields/_length attributes never get here.
      return self._values[self._FieldIndex(name)]

    def __setattr__(self, name, value):
      # Only declared fields may be assigned; anything else raises
      # AttributeError from _FieldIndex.
      self._values[self._FieldIndex(name)] = value

    @classmethod
    def __len__(cls):
      return cls._length

    def Pack(self):
      """Returns the structure serialized according to its format string."""
      return struct.pack(self._format, *self._values)

    def __str__(self):
      return "%s(%s)" % (self._name, ", ".join(
          "%s=%s" % (i, v) for i, v in zip(self._fields, self._values)))

    def __repr__(self):
      return str(self)

    def CPointer(self):
      """Returns the address of a C buffer holding the serialized structure.

      The buffer is a snapshot taken at call time. Each call stores a new
      buffer on the object in place of the previous one.
      """
      buf = ctypes.create_string_buffer(self.Pack())
      # Store the C buffer in the object so it doesn't get garbage collected.
      super(CStruct, self).__setattr__("_buffer", buf)
      return ctypes.addressof(self._buffer)

  # Make the class object have the name that's passed in. type.__init__
  # ignores its name argument, so the name has to be assigned explicitly.
  CStruct.__name__ = str(name)
  CStruct.__qualname__ = str(name)

  return CStruct
135
136
def Read(data, struct_type):
  """Parses the start of data as a struct_type.

  Args:
    data: Raw data beginning with a serialized struct_type.
    struct_type: Class to instantiate. It must support len() on the class
      itself and accept the raw data as its only constructor argument.

  Returns:
    A tuple of the parsed struct and the data that follows it.
  """
  size = len(struct_type)
  parsed = struct_type(data)
  remainder = data[size:]
  return parsed, remainder
140