1/*
2 * Copyright (C) 2010 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "MtpDataPacket"
18
19#include "MtpDataPacket.h"
20
#include <algorithm>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <sys/types.h>
#include <unistd.h>
#include <usbhost/usbhost.h>
#include "MtpStringBuffer.h"
#include "IMtpHandle.h"
28
29namespace android {
30
namespace {
// Reads exactly |count| bytes from |fd| into |buf|.
// Returns |count| on success and -1 on failure. Hitting EOF before |count| bytes have been
// read is treated as a failure. Reads interrupted by a signal (EINTR) are retried instead of
// aborting the whole transfer.
ssize_t readExactBytes(int fd, void* buf, size_t count) {
    // The result must be representable as a non-negative ssize_t.
    if (count > SSIZE_MAX) {
        return -1;
    }
    uint8_t* const dst = static_cast<uint8_t*>(buf);
    size_t readCount = 0;
    while (readCount < count) {
        // Keep read()'s ssize_t result; narrowing it to int could truncate large chunks.
        const ssize_t result = ::read(fd, dst + readCount, count - readCount);
        if (result < 0 && errno == EINTR) {
            continue;
        }
        // Assume that EOF is error.
        if (result <= 0) {
            return -1;
        }
        readCount += static_cast<size_t>(result);
    }
    return static_cast<ssize_t>(readCount);
}
}  // namespace
51
52MtpDataPacket::MtpDataPacket()
53    :   MtpPacket(MTP_BUFFER_SIZE),   // MAX_USBFS_BUFFER_SIZE
54        mOffset(MTP_CONTAINER_HEADER_SIZE)
55{
56}
57
58MtpDataPacket::~MtpDataPacket() {
59}
60
61void MtpDataPacket::reset() {
62    MtpPacket::reset();
63    mOffset = MTP_CONTAINER_HEADER_SIZE;
64}
65
66void MtpDataPacket::setOperationCode(MtpOperationCode code) {
67    MtpPacket::putUInt16(MTP_CONTAINER_CODE_OFFSET, code);
68}
69
70void MtpDataPacket::setTransactionID(MtpTransactionID id) {
71    MtpPacket::putUInt32(MTP_CONTAINER_TRANSACTION_ID_OFFSET, id);
72}
73
74bool MtpDataPacket::getUInt8(uint8_t& value) {
75    if (mPacketSize - mOffset < sizeof(value))
76        return false;
77    value = mBuffer[mOffset++];
78    return true;
79}
80
81bool MtpDataPacket::getUInt16(uint16_t& value) {
82    if (mPacketSize - mOffset < sizeof(value))
83        return false;
84    int offset = mOffset;
85    value = (uint16_t)mBuffer[offset] | ((uint16_t)mBuffer[offset + 1] << 8);
86    mOffset += sizeof(value);
87    return true;
88}
89
90bool MtpDataPacket::getUInt32(uint32_t& value) {
91    if (mPacketSize - mOffset < sizeof(value))
92        return false;
93    int offset = mOffset;
94    value = (uint32_t)mBuffer[offset] | ((uint32_t)mBuffer[offset + 1] << 8) |
95           ((uint32_t)mBuffer[offset + 2] << 16)  | ((uint32_t)mBuffer[offset + 3] << 24);
96    mOffset += sizeof(value);
97    return true;
98}
99
100bool MtpDataPacket::getUInt64(uint64_t& value) {
101    if (mPacketSize - mOffset < sizeof(value))
102        return false;
103    int offset = mOffset;
104    value = (uint64_t)mBuffer[offset] | ((uint64_t)mBuffer[offset + 1] << 8) |
105           ((uint64_t)mBuffer[offset + 2] << 16) | ((uint64_t)mBuffer[offset + 3] << 24) |
106           ((uint64_t)mBuffer[offset + 4] << 32) | ((uint64_t)mBuffer[offset + 5] << 40) |
107           ((uint64_t)mBuffer[offset + 6] << 48)  | ((uint64_t)mBuffer[offset + 7] << 56);
108    mOffset += sizeof(value);
109    return true;
110}
111
112bool MtpDataPacket::getUInt128(uint128_t& value) {
113    return getUInt32(value[0]) && getUInt32(value[1]) && getUInt32(value[2]) && getUInt32(value[3]);
114}
115
116bool MtpDataPacket::getString(MtpStringBuffer& string)
117{
118    return string.readFromPacket(this);
119}
120
121Int8List* MtpDataPacket::getAInt8() {
122    uint32_t count;
123    if (!getUInt32(count))
124        return NULL;
125    Int8List* result = new Int8List;
126    for (uint32_t i = 0; i < count; i++) {
127        int8_t value;
128        if (!getInt8(value)) {
129            delete result;
130            return NULL;
131        }
132        result->push(value);
133    }
134    return result;
135}
136
137UInt8List* MtpDataPacket::getAUInt8() {
138    uint32_t count;
139    if (!getUInt32(count))
140        return NULL;
141    UInt8List* result = new UInt8List;
142    for (uint32_t i = 0; i < count; i++) {
143        uint8_t value;
144        if (!getUInt8(value)) {
145            delete result;
146            return NULL;
147        }
148        result->push(value);
149    }
150    return result;
151}
152
153Int16List* MtpDataPacket::getAInt16() {
154    uint32_t count;
155    if (!getUInt32(count))
156        return NULL;
157    Int16List* result = new Int16List;
158    for (uint32_t i = 0; i < count; i++) {
159        int16_t value;
160        if (!getInt16(value)) {
161            delete result;
162            return NULL;
163        }
164        result->push(value);
165    }
166    return result;
167}
168
169UInt16List* MtpDataPacket::getAUInt16() {
170    uint32_t count;
171    if (!getUInt32(count))
172        return NULL;
173    UInt16List* result = new UInt16List;
174    for (uint32_t i = 0; i < count; i++) {
175        uint16_t value;
176        if (!getUInt16(value)) {
177            delete result;
178            return NULL;
179        }
180        result->push(value);
181    }
182    return result;
183}
184
185Int32List* MtpDataPacket::getAInt32() {
186    uint32_t count;
187    if (!getUInt32(count))
188        return NULL;
189    Int32List* result = new Int32List;
190    for (uint32_t i = 0; i < count; i++) {
191        int32_t value;
192        if (!getInt32(value)) {
193            delete result;
194            return NULL;
195        }
196        result->push(value);
197    }
198    return result;
199}
200
201UInt32List* MtpDataPacket::getAUInt32() {
202    uint32_t count;
203    if (!getUInt32(count))
204        return NULL;
205    UInt32List* result = new UInt32List;
206    for (uint32_t i = 0; i < count; i++) {
207        uint32_t value;
208        if (!getUInt32(value)) {
209            delete result;
210            return NULL;
211        }
212        result->push(value);
213    }
214    return result;
215}
216
217Int64List* MtpDataPacket::getAInt64() {
218    uint32_t count;
219    if (!getUInt32(count))
220        return NULL;
221    Int64List* result = new Int64List;
222    for (uint32_t i = 0; i < count; i++) {
223        int64_t value;
224        if (!getInt64(value)) {
225            delete result;
226            return NULL;
227        }
228        result->push(value);
229    }
230    return result;
231}
232
233UInt64List* MtpDataPacket::getAUInt64() {
234    uint32_t count;
235    if (!getUInt32(count))
236        return NULL;
237    UInt64List* result = new UInt64List;
238    for (uint32_t i = 0; i < count; i++) {
239        uint64_t value;
240        if (!getUInt64(value)) {
241            delete result;
242            return NULL;
243        }
244        result->push(value);
245    }
246    return result;
247}
248
249void MtpDataPacket::putInt8(int8_t value) {
250    allocate(mOffset + 1);
251    mBuffer[mOffset++] = (uint8_t)value;
252    if (mPacketSize < mOffset)
253        mPacketSize = mOffset;
254}
255
256void MtpDataPacket::putUInt8(uint8_t value) {
257    allocate(mOffset + 1);
258    mBuffer[mOffset++] = (uint8_t)value;
259    if (mPacketSize < mOffset)
260        mPacketSize = mOffset;
261}
262
263void MtpDataPacket::putInt16(int16_t value) {
264    allocate(mOffset + 2);
265    mBuffer[mOffset++] = (uint8_t)(value & 0xFF);
266    mBuffer[mOffset++] = (uint8_t)((value >> 8) & 0xFF);
267    if (mPacketSize < mOffset)
268        mPacketSize = mOffset;
269}
270
271void MtpDataPacket::putUInt16(uint16_t value) {
272    allocate(mOffset + 2);
273    mBuffer[mOffset++] = (uint8_t)(value & 0xFF);
274    mBuffer[mOffset++] = (uint8_t)((value >> 8) & 0xFF);
275    if (mPacketSize < mOffset)
276        mPacketSize = mOffset;
277}
278
279void MtpDataPacket::putInt32(int32_t value) {
280    allocate(mOffset + 4);
281    mBuffer[mOffset++] = (uint8_t)(value & 0xFF);
282    mBuffer[mOffset++] = (uint8_t)((value >> 8) & 0xFF);
283    mBuffer[mOffset++] = (uint8_t)((value >> 16) & 0xFF);
284    mBuffer[mOffset++] = (uint8_t)((value >> 24) & 0xFF);
285    if (mPacketSize < mOffset)
286        mPacketSize = mOffset;
287}
288
289void MtpDataPacket::putUInt32(uint32_t value) {
290    allocate(mOffset + 4);
291    mBuffer[mOffset++] = (uint8_t)(value & 0xFF);
292    mBuffer[mOffset++] = (uint8_t)((value >> 8) & 0xFF);
293    mBuffer[mOffset++] = (uint8_t)((value >> 16) & 0xFF);
294    mBuffer[mOffset++] = (uint8_t)((value >> 24) & 0xFF);
295    if (mPacketSize < mOffset)
296        mPacketSize = mOffset;
297}
298
299void MtpDataPacket::putInt64(int64_t value) {
300    allocate(mOffset + 8);
301    mBuffer[mOffset++] = (uint8_t)(value & 0xFF);
302    mBuffer[mOffset++] = (uint8_t)((value >> 8) & 0xFF);
303    mBuffer[mOffset++] = (uint8_t)((value >> 16) & 0xFF);
304    mBuffer[mOffset++] = (uint8_t)((value >> 24) & 0xFF);
305    mBuffer[mOffset++] = (uint8_t)((value >> 32) & 0xFF);
306    mBuffer[mOffset++] = (uint8_t)((value >> 40) & 0xFF);
307    mBuffer[mOffset++] = (uint8_t)((value >> 48) & 0xFF);
308    mBuffer[mOffset++] = (uint8_t)((value >> 56) & 0xFF);
309    if (mPacketSize < mOffset)
310        mPacketSize = mOffset;
311}
312
313void MtpDataPacket::putUInt64(uint64_t value) {
314    allocate(mOffset + 8);
315    mBuffer[mOffset++] = (uint8_t)(value & 0xFF);
316    mBuffer[mOffset++] = (uint8_t)((value >> 8) & 0xFF);
317    mBuffer[mOffset++] = (uint8_t)((value >> 16) & 0xFF);
318    mBuffer[mOffset++] = (uint8_t)((value >> 24) & 0xFF);
319    mBuffer[mOffset++] = (uint8_t)((value >> 32) & 0xFF);
320    mBuffer[mOffset++] = (uint8_t)((value >> 40) & 0xFF);
321    mBuffer[mOffset++] = (uint8_t)((value >> 48) & 0xFF);
322    mBuffer[mOffset++] = (uint8_t)((value >> 56) & 0xFF);
323    if (mPacketSize < mOffset)
324        mPacketSize = mOffset;
325}
326
327void MtpDataPacket::putInt128(const int128_t& value) {
328    putInt32(value[0]);
329    putInt32(value[1]);
330    putInt32(value[2]);
331    putInt32(value[3]);
332}
333
334void MtpDataPacket::putUInt128(const uint128_t& value) {
335    putUInt32(value[0]);
336    putUInt32(value[1]);
337    putUInt32(value[2]);
338    putUInt32(value[3]);
339}
340
341void MtpDataPacket::putInt128(int64_t value) {
342    putInt64(value);
343    putInt64(value < 0 ? -1 : 0);
344}
345
346void MtpDataPacket::putUInt128(uint64_t value) {
347    putUInt64(value);
348    putUInt64(0);
349}
350
351void MtpDataPacket::putAInt8(const int8_t* values, int count) {
352    putUInt32(count);
353    for (int i = 0; i < count; i++)
354        putInt8(*values++);
355}
356
357void MtpDataPacket::putAUInt8(const uint8_t* values, int count) {
358    putUInt32(count);
359    for (int i = 0; i < count; i++)
360        putUInt8(*values++);
361}
362
363void MtpDataPacket::putAInt16(const int16_t* values, int count) {
364    putUInt32(count);
365    for (int i = 0; i < count; i++)
366        putInt16(*values++);
367}
368
369void MtpDataPacket::putAUInt16(const uint16_t* values, int count) {
370    putUInt32(count);
371    for (int i = 0; i < count; i++)
372        putUInt16(*values++);
373}
374
375void MtpDataPacket::putAUInt16(const UInt16List* values) {
376    size_t count = (values ? values->size() : 0);
377    putUInt32(count);
378    for (size_t i = 0; i < count; i++)
379        putUInt16((*values)[i]);
380}
381
382void MtpDataPacket::putAInt32(const int32_t* values, int count) {
383    putUInt32(count);
384    for (int i = 0; i < count; i++)
385        putInt32(*values++);
386}
387
388void MtpDataPacket::putAUInt32(const uint32_t* values, int count) {
389    putUInt32(count);
390    for (int i = 0; i < count; i++)
391        putUInt32(*values++);
392}
393
394void MtpDataPacket::putAUInt32(const UInt32List* list) {
395    if (!list) {
396        putEmptyArray();
397    } else {
398        size_t size = list->size();
399        putUInt32(size);
400        for (size_t i = 0; i < size; i++)
401            putUInt32((*list)[i]);
402    }
403}
404
405void MtpDataPacket::putAInt64(const int64_t* values, int count) {
406    putUInt32(count);
407    for (int i = 0; i < count; i++)
408        putInt64(*values++);
409}
410
411void MtpDataPacket::putAUInt64(const uint64_t* values, int count) {
412    putUInt32(count);
413    for (int i = 0; i < count; i++)
414        putUInt64(*values++);
415}
416
417void MtpDataPacket::putString(const MtpStringBuffer& string) {
418    string.writeToPacket(this);
419}
420
421void MtpDataPacket::putString(const char* s) {
422    MtpStringBuffer string(s);
423    string.writeToPacket(this);
424}
425
426void MtpDataPacket::putString(const uint16_t* string) {
427    int count = 0;
428    for (int i = 0; i <= MTP_STRING_MAX_CHARACTER_NUMBER; i++) {
429        if (string[i])
430            count++;
431        else
432            break;
433    }
434    putUInt8(count > 0 ? count + 1 : 0);
435    for (int i = 0; i < count; i++)
436        putUInt16(string[i]);
437    // only terminate with zero if string is not empty
438    if (count > 0)
439        putUInt16(0);
440}
441
442#ifdef MTP_DEVICE
443int MtpDataPacket::read(IMtpHandle *h) {
444    int ret = h->read(mBuffer, MTP_BUFFER_SIZE);
445    if (ret < MTP_CONTAINER_HEADER_SIZE)
446        return -1;
447    mPacketSize = ret;
448    mOffset = MTP_CONTAINER_HEADER_SIZE;
449    return ret;
450}
451
452int MtpDataPacket::write(IMtpHandle *h) {
453    MtpPacket::putUInt32(MTP_CONTAINER_LENGTH_OFFSET, mPacketSize);
454    MtpPacket::putUInt16(MTP_CONTAINER_TYPE_OFFSET, MTP_CONTAINER_TYPE_DATA);
455    int ret = h->write(mBuffer, mPacketSize);
456    return (ret < 0 ? ret : 0);
457}
458
459int MtpDataPacket::writeData(IMtpHandle *h, void* data, uint32_t length) {
460    allocate(length + MTP_CONTAINER_HEADER_SIZE);
461    memcpy(mBuffer + MTP_CONTAINER_HEADER_SIZE, data, length);
462    length += MTP_CONTAINER_HEADER_SIZE;
463    MtpPacket::putUInt32(MTP_CONTAINER_LENGTH_OFFSET, length);
464    MtpPacket::putUInt16(MTP_CONTAINER_TYPE_OFFSET, MTP_CONTAINER_TYPE_DATA);
465    int ret = h->write(mBuffer, length);
466    return (ret < 0 ? ret : 0);
467}
468
469#endif // MTP_DEVICE
470
471#ifdef MTP_HOST
472int MtpDataPacket::read(struct usb_request *request) {
473    // first read the header
474    request->buffer = mBuffer;
475    request->buffer_length = mBufferSize;
476    int length = transfer(request);
477    if (length >= MTP_CONTAINER_HEADER_SIZE) {
478        // look at the length field to see if the data spans multiple packets
479        uint32_t totalLength = MtpPacket::getUInt32(MTP_CONTAINER_LENGTH_OFFSET);
480        allocate(totalLength);
481        while (totalLength > static_cast<uint32_t>(length)) {
482            request->buffer = mBuffer + length;
483            request->buffer_length = totalLength - length;
484            int ret = transfer(request);
485            if (ret >= 0)
486                length += ret;
487            else {
488                length = ret;
489                break;
490            }
491        }
492    }
493    if (length >= 0)
494        mPacketSize = length;
495    return length;
496}
497
498int MtpDataPacket::readData(struct usb_request *request, void* buffer, int length) {
499    int read = 0;
500    while (read < length) {
501        request->buffer = (char *)buffer + read;
502        request->buffer_length = length - read;
503        int ret = transfer(request);
504        if (ret < 0) {
505            return ret;
506        }
507        read += ret;
508    }
509    return read;
510}
511
512// Queue a read request.  Call readDataWait to wait for result
513int MtpDataPacket::readDataAsync(struct usb_request *req) {
514    if (usb_request_queue(req)) {
515        ALOGE("usb_endpoint_queue failed, errno: %d", errno);
516        return -1;
517    }
518    return 0;
519}
520
521// Wait for result of readDataAsync
522int MtpDataPacket::readDataWait(struct usb_device *device) {
523    struct usb_request *req = usb_request_wait(device, -1);
524    return (req ? req->actual_length : -1);
525}
526
527int MtpDataPacket::readDataHeader(struct usb_request *request) {
528    request->buffer = mBuffer;
529    request->buffer_length = request->max_packet_size;
530    int length = transfer(request);
531    if (length >= 0)
532        mPacketSize = length;
533    return length;
534}
535
536int MtpDataPacket::write(struct usb_request *request, UrbPacketDivisionMode divisionMode) {
537    if (mPacketSize < MTP_CONTAINER_HEADER_SIZE || mPacketSize > MTP_BUFFER_SIZE) {
538        ALOGE("Illegal packet size.");
539        return -1;
540    }
541
542    MtpPacket::putUInt32(MTP_CONTAINER_LENGTH_OFFSET, mPacketSize);
543    MtpPacket::putUInt16(MTP_CONTAINER_TYPE_OFFSET, MTP_CONTAINER_TYPE_DATA);
544
545    size_t processedBytes = 0;
546    while (processedBytes < mPacketSize) {
547        const size_t write_size =
548                processedBytes == 0 && divisionMode == FIRST_PACKET_ONLY_HEADER ?
549                        MTP_CONTAINER_HEADER_SIZE : mPacketSize - processedBytes;
550        request->buffer = mBuffer + processedBytes;
551        request->buffer_length = write_size;
552        const int result = transfer(request);
553        if (result < 0) {
554            ALOGE("Failed to write bytes to the device.");
555            return -1;
556        }
557        processedBytes += result;
558    }
559
560    return processedBytes == mPacketSize ? processedBytes : -1;
561}
562
563int MtpDataPacket::write(struct usb_request *request,
564                         UrbPacketDivisionMode divisionMode,
565                         int fd,
566                         size_t payloadSize) {
567    // Obtain the greatest multiple of minimum packet size that is not greater than
568    // MTP_BUFFER_SIZE.
569    if (request->max_packet_size <= 0) {
570        ALOGE("Cannot determine bulk transfer size due to illegal max packet size %d.",
571              request->max_packet_size);
572        return -1;
573    }
574    const size_t maxBulkTransferSize =
575            MTP_BUFFER_SIZE - (MTP_BUFFER_SIZE % request->max_packet_size);
576    const size_t containerLength = payloadSize + MTP_CONTAINER_HEADER_SIZE;
577    size_t processedBytes = 0;
578    bool readError = false;
579
580    // Bind the packet with given request.
581    request->buffer = mBuffer;
582    allocate(maxBulkTransferSize);
583
584    while (processedBytes < containerLength) {
585        size_t bulkTransferSize = 0;
586
587        // prepare header.
588        const bool headerSent = processedBytes != 0;
589        if (!headerSent) {
590            MtpPacket::putUInt32(MTP_CONTAINER_LENGTH_OFFSET, containerLength);
591            MtpPacket::putUInt16(MTP_CONTAINER_TYPE_OFFSET, MTP_CONTAINER_TYPE_DATA);
592            bulkTransferSize += MTP_CONTAINER_HEADER_SIZE;
593        }
594
595        // Prepare payload.
596        if (headerSent || divisionMode == FIRST_PACKET_HAS_PAYLOAD) {
597            const size_t processedPayloadBytes =
598                    headerSent ? processedBytes - MTP_CONTAINER_HEADER_SIZE : 0;
599            const size_t maxRead = payloadSize - processedPayloadBytes;
600            const size_t maxWrite = maxBulkTransferSize - bulkTransferSize;
601            const size_t bulkTransferPayloadSize = std::min(maxRead, maxWrite);
602            // prepare payload.
603            if (!readError) {
604                const ssize_t result = readExactBytes(
605                        fd,
606                        mBuffer + bulkTransferSize,
607                        bulkTransferPayloadSize);
608                if (result < 0) {
609                    ALOGE("Found an error while reading data from FD. Send 0 data instead.");
610                    readError = true;
611                }
612            }
613            if (readError) {
614                memset(mBuffer + bulkTransferSize, 0, bulkTransferPayloadSize);
615            }
616            bulkTransferSize += bulkTransferPayloadSize;
617        }
618
619        // Bulk transfer.
620        mPacketSize = bulkTransferSize;
621        request->buffer_length = bulkTransferSize;
622        const int result = transfer(request);
623        if (result != static_cast<ssize_t>(bulkTransferSize)) {
624            // Cannot recover writing error.
625            ALOGE("Found an error while write data to MtpDevice.");
626            return -1;
627        }
628
629        // Update variables.
630        processedBytes += bulkTransferSize;
631    }
632
633    return readError ? -1 : processedBytes;
634}
635
636#endif // MTP_HOST
637
638void* MtpDataPacket::getData(int* outLength) const {
639    int length = mPacketSize - MTP_CONTAINER_HEADER_SIZE;
640    if (length > 0) {
641        void* result = malloc(length);
642        if (result) {
643            memcpy(result, mBuffer + MTP_CONTAINER_HEADER_SIZE, length);
644            *outLength = length;
645            return result;
646        }
647    }
648    *outLength = 0;
649    return NULL;
650}
651
652}  // namespace android
653