1# Copyright 2015 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""
5MBIM Data transfer module is responsible for generating valid MBIM NTB frames
6from  IP packets and for extracting IP packets from received MBIM NTB frames.
7
8"""
9import array
10import struct
11from collections import namedtuple
12
13from autotest_lib.client.cros.cellular.mbim_compliance import mbim_constants
14from autotest_lib.client.cros.cellular.mbim_compliance \
15        import mbim_data_channel
16from autotest_lib.client.cros.cellular.mbim_compliance import mbim_errors
17
18
# Each signature is the little-endian 32-bit value of a 4-byte ASCII tag. The
# NDP tags end in a 0x00 byte rather than an ASCII '0'.
NTH_SIGNATURE_32 = struct.unpack('<I', b'ncmh')[0]  # 0x686D636E
NDP_SIGNATURE_IPS_32 = struct.unpack('<I', b'ips\x00')[0]  # 0x00737069
NDP_SIGNATURE_DSS_32 = struct.unpack('<I', b'dss\x00')[0]  # 0x00737364

NTH_SIGNATURE_16 = struct.unpack('<I', b'NCMH')[0]  # 0x484D434E
NDP_SIGNATURE_IPS_16 = struct.unpack('<I', b'IPS\x00')[0]  # 0x00535049
NDP_SIGNATURE_DSS_16 = struct.unpack('<I', b'DSS\x00')[0]  # 0x00535344
26
class MBIMDataTransfer(object):
    """
    Public entry point for moving packets to and from the device over the
    MBIM data endpoints (BULK-IN / BULK-OUT).

    The class hides MBIM NTB frame construction/parsing and the bulk
    transfers behind two calls. Typical usage:
    1. Construct it with a device context describing the device under test.
    2. Call send_data_packets to transmit IP packets to the device.
    3. Call receive_data_packets to collect IP packets from the device.

    """
    def __init__(self, device_context):
        """
        Stores the device context and opens a data channel on the device's
        MBIM data interface and its bulk endpoints.

        @param device_context: The device context which contains all the USB
                descriptors, NTB params and USB handle to the device.

        """
        self._device_context = device_context
        descriptors = device_context.descriptor_cache
        interface_number = descriptors.mbim_data_interface.bInterfaceNumber
        in_address = descriptors.bulk_in_endpoint.bEndpointAddress
        out_address = descriptors.bulk_out_endpoint.bEndpointAddress
        self._data_channel = mbim_data_channel.MBIMDataChannel(
                device=device_context.device,
                data_interface_number=interface_number,
                bulk_in_endpoint_address=in_address,
                bulk_out_endpoint_address=out_address,
                max_in_buffer_size=device_context.max_in_data_transfer_size)


    def send_data_packets(self, ntb_format, data_packets):
        """
        Packs |data_packets| into a single NTB frame and writes it to the
        device's bulk-out pipe.

        The frame size limit and the payload/NDP alignment come from the
        device context's out_data_transfer_* / max_out_data_transfer_size
        values.

        @param ntb_format: Whether to send an NTB16 or NTB32 frame.
        @param data_packets: Array of data packets. Each packet is a byte array
                corresponding to the IP packet or any other payload to be sent.

        """
        ctx = self._device_context
        frame = MBIMNtb(ntb_format).generate_ntb(
                data_packets,
                ctx.max_out_data_transfer_size,
                ctx.out_data_transfer_divisor,
                ctx.out_data_transfer_payload_remainder,
                ctx.out_data_transfer_ndp_alignment)
        self._data_channel.send_ntb(frame)


    def receive_data_packets(self, ntb_format):
        """
        Reads one NTB frame from the device's bulk-in pipe and splits it into
        its headers and packets.

        @param ntb_format: Whether to receive an NTB16 or NTB32 frame.
        @returns An empty tuple when the channel yields no frame; otherwise a
                tuple of (nth, ndp, ndp_entries, payload) where,
                nth - NTH header object received.
                ndp - NDP header object received.
                ndp_entries - Array of NDP entry header objects.
                payload - Array of packets where each packet is a byte array.

        """
        received = self._data_channel.receive_ntb()
        if received:
            return MBIMNtb(ntb_format).parse_ntb(received)
        return ()
107
108
class MBIMNtb(object):
    """
    MBIM NTB class used for MBIM data transfer.

    This class is used to generate/parse NTB frames.

    Limitations:
    1. We currently only support a single NDP frame within an NTB.
    2. We only support IP data payload. This can be overcome by using the DSS
            (instead of IPS) prefix in NDP signature if required.

    """
    # Next NTH sequence number to hand out. It is always read and written
    # through MBIMNtb itself so every NTB object (and subclass) shares one
    # counter.
    _NEXT_SEQUENCE_NUMBER = 0

    def __init__(self, ntb_format):
        """
        Initialization of the NTB object.

        We assign the appropriate header classes required based on whether
        we are going to work with NTB16 or NTB32 data frames.

        @param ntb_format: Type of NTB. mbim_constants.NTB_FORMAT_16 selects
                the 16-bit headers; any other value selects the 32-bit ones.

        """
        self._ntb_format = ntb_format
        # Defining the tuples to be used for the headers.
        if ntb_format == mbim_constants.NTB_FORMAT_16:
            self._nth_class = Nth16
            self._ndp_class = Ndp16
            self._ndp_entry_class = NdpEntry16
            self._nth_signature = NTH_SIGNATURE_16
            self._ndp_signature = NDP_SIGNATURE_IPS_16
        else:
            self._nth_class = Nth32
            self._ndp_class = Ndp32
            self._ndp_entry_class = NdpEntry32
            self._nth_signature = NTH_SIGNATURE_32
            self._ndp_signature = NDP_SIGNATURE_IPS_32


    @classmethod
    def get_next_sequence_number(cls):
        """
        Returns incrementing sequence numbers on successive calls. We start
        the sequence numbering at 0 and wrap from 0xFFFF back to 0, so every
        value fits the 16-bit ('H') sequence_number field of the NTH.

        @returns The sequence number for data transfers.

        """
        # Mask on read as well, so an out-of-range stored value can never be
        # packed into the 16-bit field.
        sequence_number = MBIMNtb._NEXT_SEQUENCE_NUMBER & 0xFFFF
        MBIMNtb._NEXT_SEQUENCE_NUMBER = (sequence_number + 1) & 0xFFFF
        return sequence_number


    @classmethod
    def reset_sequence_number(cls):
        """
        Resets the sequence number to be used for NTB's sent from host. This
        has to be done every time the device is reset.

        """
        # Reset the same attribute get_next_sequence_number reads; writing
        # to |cls| would only shadow it when called on a subclass.
        MBIMNtb._NEXT_SEQUENCE_NUMBER = 0


    def get_next_payload_offset(self,
                                current_offset,
                                ntb_divisor,
                                ntb_payload_remainder):
        """
        Helper function to find the offset to place the next payload at.

        |current_offset| is rounded up to a multiple of |ntb_divisor| and
        |ntb_payload_remainder| is then added. When
        ntb_payload_remainder < ntb_divisor, the result satisfies:
            offset % ntb_divisor == ntb_payload_remainder.

        @param current_offset: Current index offset in the frame.
        @param ntb_divisor: Used for payload alignment within the frame.
        @param ntb_payload_remainder: Used for payload alignment within the
                frame.
        @returns offset (an int) to place the next payload at.

        """
        # Floor division keeps the offset an integer on both Python 2 and 3.
        aligned_offset = (
                (current_offset + (ntb_divisor - 1)) // ntb_divisor) * ntb_divisor
        return aligned_offset + ntb_payload_remainder


    def generate_ntb(self,
                     payload,
                     max_ntb_size,
                     ntb_divisor,
                     ntb_payload_remainder,
                     ntb_ndp_alignment):
        """
        This function generates an NTB frame out of the payload provided.

        Frame layout: the NTH header, then each packet at an offset chosen by
        get_next_payload_offset, then (aligned to |ntb_ndp_alignment|) the NDP
        header, followed by one NDP entry per packet and a final all-zero
        entry.

        @param payload: Array of packets to send to the device. Each packet
                contains the raw bytes of the IP packet to be sent, in any
                form accepted by array.array('B', ...) (for example an
                array.array('B'), a bytearray or a list of byte values).
        @param max_ntb_size: Max size of NTB frame supported by the device.
        @param ntb_divisor: Used for payload alignment within the frame.
        @param ntb_payload_remainder: Used for payload alignment within the
                frame.
        @param ntb_ndp_alignment: The NDP header is placed at the next
                multiple of this value after the last payload.
        @raises MBIMComplianceNtbError if the complete |ntb| can not fit into
                |max_ntb_size|.
        @returns the raw MBIM NTB byte array.

        """
        cls = self.__class__

        # We start with the NTH header, then the payload and then finally
        # the NDP header and the associated NDP entries.
        nth_length = self._nth_class.get_struct_len()
        ndp_length = self._ndp_class.get_struct_len()
        # We need one extra ZLP NDP entry at the end, so account for it.
        ndp_entries_length = (
                self._ndp_entry_class.get_struct_len() * (len(payload) + 1))

        # Create the NDP header now. The NTH header can only be created after
        # we calculate the total length.
        self.ndp = self._ndp_class(
                signature=self._ndp_signature,
                length=ndp_length + ndp_entries_length,
                next_ndp_index=0)
        self.ndp_entries = []

        # Build the bytes that follow the NTH header, adding the alignment
        # padding in place and recording an NDP entry per packet.
        ntb_curr_offset = nth_length
        raw_ntb_frame_payload = array.array('B')
        for packet in payload:
            packet_bytes = array.array('B', packet)
            offset = self.get_next_payload_offset(
                    ntb_curr_offset, ntb_divisor, ntb_payload_remainder)
            raw_ntb_frame_payload += array.array(
                    'B', [0] * (offset - ntb_curr_offset))
            raw_ntb_frame_payload += packet_bytes
            self.ndp_entries.append(self._ndp_entry_class(
                    datagram_index=offset,
                    datagram_length=len(packet_bytes)))
            ntb_curr_offset = offset + len(packet_bytes)

        # Add the ZLP entry that terminates the NDP entry list.
        self.ndp_entries.append(self._ndp_entry_class(
                datagram_index=0, datagram_length=0))

        # Pad up to the next multiple of |ntb_ndp_alignment| before the NDP.
        # Unlike a bit mask, the modulo form is also correct when the
        # alignment is not a power of two.
        pad_length = (-ntb_curr_offset) % ntb_ndp_alignment
        raw_ntb_frame_payload += array.array('B', [0] * pad_length)
        ntb_curr_offset += pad_length
        ndp_offset = ntb_curr_offset
        ntb_curr_offset += ndp_length + ndp_entries_length
        if ntb_curr_offset > max_ntb_size:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceNtbError,
                    'Could not fit the complete NTB of size %d into %d bytes' %
                    (ntb_curr_offset, max_ntb_size))
        # Now create the NTH header.
        self.nth = self._nth_class(
                signature=self._nth_signature,
                header_length=nth_length,
                sequence_number=cls.get_next_sequence_number(),
                block_length=ntb_curr_offset,
                fp_index=ndp_offset)

        # Assemble the frame: NTH, padded payload, NDP, then the NDP entries.
        raw_ntb_frame = array.array('B')
        raw_ntb_frame += array.array('B', self.nth.pack())
        raw_ntb_frame += raw_ntb_frame_payload
        raw_ntb_frame += array.array('B', self.ndp.pack())
        for entry in self.ndp_entries:
            raw_ntb_frame += array.array('B', entry.pack())

        self.payload = payload
        self.raw_ntb_frame = raw_ntb_frame

        return raw_ntb_frame


    def parse_ntb(self, raw_ntb_frame):
        """
        This function parses an NTB frame and returns the NTH header, NDP header
        and the payload parsed which can be used to inspect the response
        from the device.

        @param raw_ntb_frame: Array of bytes of an MBIM NTB frame.
        @raises MBIMComplianceNtbError if the frame length does not match the
                NTH block_length. Header parsing errors propagate from the
                header classes.
        @returns tuple of (nth, ndp, ndp_entries, payload) where,
                nth - NTH header object received.
                ndp - NDP header object received.
                ndp_entries - Array of NDP entry header objects.
                payload - Array of packets where each packet is a byte array.

        """
        # Read the nth header to find the ndp header index.
        self.nth = self._nth_class(raw_data=raw_ntb_frame)
        ndp_offset = self.nth.fp_index
        # Verify the total length field.
        if len(raw_ntb_frame) != self.nth.block_length:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceNtbError,
                    'NTB size mismatch Total length: %x Reported: %x bytes' % (
                            len(raw_ntb_frame), self.nth.block_length))

        # Read the NDP header to find the number of packets in the entry.
        self.ndp = self._ndp_class(raw_data=raw_ntb_frame[ndp_offset:])
        ndp_header_length = self._ndp_class.get_struct_len()
        ndp_entry_length = self._ndp_entry_class.get_struct_len()
        # Floor division keeps the count an int so range() accepts it.
        num_ndp_entries = (
                (self.ndp.length - ndp_header_length) // ndp_entry_length)
        ndp_entries_offset = ndp_offset + ndp_header_length
        self.payload = []
        self.ndp_entries = []
        for _ in range(num_ndp_entries):
            ndp_entry = self._ndp_entry_class(
                    raw_data=raw_ntb_frame[ndp_entries_offset:])
            ndp_entries_offset += ndp_entry_length
            packet_start_offset = ndp_entry.datagram_index
            packet_end_offset = (
                    ndp_entry.datagram_index + ndp_entry.datagram_length)
            # The trailing ZLP entry has zero index/length and carries no
            # packet.
            if ndp_entry.datagram_index and ndp_entry.datagram_length:
                packet = array.array('B', raw_ntb_frame[packet_start_offset:
                                                        packet_end_offset])
                self.payload.append(packet)
            self.ndp_entries.append(ndp_entry)

        self.raw_ntb_frame = raw_ntb_frame

        return (self.nth, self.ndp, self.ndp_entries, self.payload)
349
350
def header_class_new(cls, **kwargs):
    """
    Creates a header instance with either the given field name/value
    pairs or raw data buffer.

    This is installed as the __new__ of every header class built by
    MBIMNtbHeadersMeta.

    @param cls: The header class being instantiated.
    @param kwargs: Dictionary of (field_name, field_value) pairs or
            raw_data=Packed binary array. When raw_data is given, all other
            keyword arguments are ignored. Otherwise, fields that are not
            supplied are set to 0.
    @raises MBIMComplianceDataTransferError if |raw_data| is shorter than the
            header structure (including empty or None), or if unknown field
            names are supplied.
    @returns New header object created.

    """
    if 'raw_data' in kwargs:
        raw_data = kwargs['raw_data']
        data_format = cls.get_field_format_string()
        unpack_length = cls.get_struct_len()
        # Treat an empty/None buffer as zero-length data, so it is reported
        # as too short rather than as an unexpected 'raw_data' field.
        data_length = len(raw_data) if raw_data else 0
        if data_length < unpack_length:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceDataTransferError,
                    'Length of Data (%d) to be parsed less than header'
                    ' structure length (%d)' %
                    (data_length, unpack_length))
        field_values = struct.unpack_from(data_format, raw_data)
    else:
        field_values = []
        for field_name in cls.get_field_names():
            # Unspecified fields (e.g. reserved ones) default to zero.
            field_values.append(kwargs.pop(field_name, 0))
        # Anything left over does not name a field of this header.
        if kwargs:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceDataTransferError,
                    'Unexpected fields (%s) in %s' % (
                            kwargs.keys(), cls.__name__))
    return super(cls, cls).__new__(cls, *field_values)
389
390
class MBIMNtbHeadersMeta(type):
    """
    Metaclass for all the NTB headers. This is a deliberately minimal
    metaclass that turns each header class into a namedtuple of its fields.

    Header definition attributes:
    _FIELDS: Used to define structure elements. Each element contains a format
            specifier and the field name.

    """
    def __new__(mcs, name, bases, attrs):
        """
        Creates a header class.

        A class that derives directly from object (the MBIMNtbHeaders base) is
        created unchanged. Any other class gets a namedtuple of its _FIELDS
        names prepended to its bases, and header_class_new as its __new__.

        @param mcs: This metaclass.
        @param name: Name of the class being created.
        @param bases: Base classes of the class being created.
        @param attrs: Class attribute dictionary.
        @raises MBIMComplianceDataTransferError if a header class defines no
                (or empty) _FIELDS.
        @returns The new class.

        """
        if object in bases:
            return super(MBIMNtbHeadersMeta, mcs).__new__(
                    mcs, name, bases, attrs)
        # Use .get() so a missing _FIELDS reaches the descriptive error below
        # instead of surfacing as a bare KeyError.
        fields = attrs.get('_FIELDS')
        if not fields:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceDataTransferError,
                    '%s header must have some fields defined' % name)
        _, field_names = zip(*fields)
        attrs['__new__'] = header_class_new
        header_class = namedtuple(name, field_names)
        # Prepend the class created via namedtuple to |bases| in order to
        # correctly resolve the __new__ method while preserving the class
        # hierarchy.
        cls = super(MBIMNtbHeadersMeta, mcs).__new__(
                mcs, name, (header_class,) + bases, attrs)
        return cls
419
420
class MBIMNtbHeaders(object):
    """
    Common base for every NTB header type (NTH, NDP and NDP entry).

    Do not instantiate this class directly. MBIMNtbHeadersMeta turns each
    concrete subclass into a namedtuple whose constructor accepts either:
    1. raw_data=<packed bytes>, which is unpacked into the fields, or
    2. field values by name, where any omitted field (for example a reserved
       one) is set to zero.

    """
    __metaclass__ = MBIMNtbHeadersMeta

    @classmethod
    def get_fields(cls):
        """
        Returns the layout description of this header.

        @returns The |_FIELDS| tuple of (format, name) pairs.

        """
        return cls._FIELDS


    @classmethod
    def get_field_names(cls):
        """
        Lists the header's field names in the order they are packed.

        @returns Tuple of field names.

        """
        unused_formats, names = zip(*cls.get_fields())
        return names


    @classmethod
    def get_field_formats(cls):
        """
        Lists the struct format code of each field in the order they are
        packed.

        @returns Tuple of struct format codes.

        """
        formats, unused_names = zip(*cls.get_fields())
        return formats


    @classmethod
    def get_field_format_string(cls):
        """
        Builds the complete little-endian struct format for the header.

        @returns '<' followed by every field's format code.

        """
        return '<%s' % ''.join(cls.get_field_formats())


    @classmethod
    def get_struct_len(cls):
        """
        Computes how many bytes the packed header occupies.

        @returns Size in bytes of the packed header.

        """
        header_format = cls.get_field_format_string()
        return struct.calcsize(header_format)


    def pack(self):
        """
        Serializes the header's field values with its format string.

        @returns The packed header as an array.array('B').

        """
        header_type = self.__class__
        values = [getattr(self, field)
                  for field in header_type.get_field_names()]
        packed = struct.pack(header_type.get_field_format_string(), *values)
        return array.array('B', packed)
505
506
class Nth16(MBIMNtbHeaders):
    """
    NTB header (NTH16) at the start of every 16-bit NTB frame.

    Fields, in packed (little-endian) order:
        signature - NTH_SIGNATURE_16 for frames built by MBIMNtb.
        header_length - Size of this header in bytes.
        sequence_number - 16-bit NTB sequence number.
        block_length - Total length of the NTB frame in bytes.
        fp_index - Byte offset of the NDP header within the frame.

    """
    _FIELDS = (
        ('I', 'signature'),
        ('H', 'header_length'),
        ('H', 'sequence_number'),
        ('H', 'block_length'),
        ('H', 'fp_index'),
    )
514
515
class Ndp16(MBIMNtbHeaders):
    """
    Datagram pointer header (NDP16) of a 16-bit NTB frame.

    Fields, in packed (little-endian) order:
        signature - NDP signature, e.g. NDP_SIGNATURE_IPS_16.
        length - Size in bytes of this header plus every NdpEntry16 that
                follows it.
        next_ndp_index - Offset of a following NDP. MBIMNtb always writes 0
                because it builds a single NDP per NTB.

    """
    _FIELDS = (
        ('I', 'signature'),
        ('H', 'length'),
        ('H', 'next_ndp_index'),
    )
521
522
class NdpEntry16(MBIMNtbHeaders):
    """
    The class for MBIM NDP16 entry objects (one per datagram after an Ndp16).

    datagram_index is the byte offset of the datagram within the NTB frame and
    datagram_length its size in bytes. MBIMNtb terminates the entry list with
    an entry whose fields are both zero.

    """
    _FIELDS = (('H', 'datagram_index'),
               ('H', 'datagram_length'))
527
528
class Nth32(MBIMNtbHeaders):
    """
    NTB header (NTH32) at the start of every 32-bit NTB frame.

    Fields, in packed (little-endian) order:
        signature - NTH_SIGNATURE_32 for frames built by MBIMNtb.
        header_length - Size of this header in bytes.
        sequence_number - 16-bit NTB sequence number.
        block_length - Total length of the NTB frame in bytes (32-bit).
        fp_index - Byte offset of the NDP header within the frame (32-bit).

    """
    _FIELDS = (
        ('I', 'signature'),
        ('H', 'header_length'),
        ('H', 'sequence_number'),
        ('I', 'block_length'),
        ('I', 'fp_index'),
    )
536
537
class Ndp32(MBIMNtbHeaders):
    """
    The class for MBIM NDP32 objects (datagram pointer header of an NTB32).

    The reserved_6 and reserved_12 fields are left at zero when the header is
    built from field values, because unspecified fields default to zero.

    """
    _FIELDS = (('I', 'signature'),
               ('H', 'length'),
               ('H', 'reserved_6'),
               ('I', 'next_ndp_index'),
               ('I', 'reserved_12'))
545
546
class NdpEntry32(MBIMNtbHeaders):
    """
    The class for MBIM NDP32 entry objects (one per datagram after an Ndp32).

    datagram_index is the byte offset of the datagram within the NTB frame and
    datagram_length its size in bytes, both 32-bit. MBIMNtb terminates the
    entry list with an entry whose fields are both zero.

    """
    _FIELDS = (('I', 'datagram_index'),
               ('I', 'datagram_length'))
551
552