1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 *  * Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 *  * Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in
12 *    the documentation and/or other materials provided with the
13 *    distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
18 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
19 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
20 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
21 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
22 * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
23 * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
25 * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26 * SUCH DAMAGE.
27 */
28
29// This file implements the fastboot UDP protocol; see fastboot_protocol.txt for documentation.
30
31#include "udp.h"
32
33#include <errno.h>
34#include <stdio.h>
35
36#include <list>
37#include <memory>
38#include <vector>
39
40#include <android-base/macros.h>
41#include <android-base/stringprintf.h>
42
43#include "socket.h"
44
45namespace udp {
46
47using namespace internal;
48
// Smallest packet size accepted from a target, in bytes; also the receive buffer size used until
// a larger size has been negotiated.
constexpr size_t kMinPacketSize = 512;
// Number of header bytes at the start of every packet.
constexpr size_t kHeaderSize = 4;

// Byte offsets of the individual fields inside a packet header.
enum Index {
    kIndexId,     // Packet ID.
    kIndexFlags,  // Packet flags.
    kIndexSeqH,   // Sequence number, high byte.
    kIndexSeqL,   // Sequence number, low byte.
};
58
// Extracts a big-endian uint16_t from a byte array.
//
// |bytes| must point to at least two readable bytes; bytes[0] is the high byte and bytes[1] is
// the low byte of the result.
static constexpr uint16_t ExtractUint16(const uint8_t* bytes) {
    // The shift and OR operate on operands promoted to int, so narrow the result back explicitly
    // rather than relying on an implicit conversion at the return.
    return static_cast<uint16_t>((bytes[0] << 8) | bytes[1]);
}
63
64// Packet header handling.
65class Header {
66  public:
67    Header();
68    ~Header() = default;
69
70    uint8_t id() const { return bytes_[kIndexId]; }
71    const uint8_t* bytes() const { return bytes_; }
72
73    void Set(uint8_t id, uint16_t sequence, Flag flag);
74
75    // Checks whether |response| is a match for this header.
76    bool Matches(const uint8_t* response);
77
78  private:
79    uint8_t bytes_[kHeaderSize];
80};
81
82Header::Header() {
83    Set(kIdError, 0, kFlagNone);
84}
85
86void Header::Set(uint8_t id, uint16_t sequence, Flag flag) {
87    bytes_[kIndexId] = id;
88    bytes_[kIndexFlags] = flag;
89    bytes_[kIndexSeqH] = sequence >> 8;
90    bytes_[kIndexSeqL] = sequence;
91}
92
93bool Header::Matches(const uint8_t* response) {
94    // Sequence numbers must be the same to match, but the response ID can either be the same
95    // or an error response which is always accepted.
96    return bytes_[kIndexSeqH] == response[kIndexSeqH] &&
97           bytes_[kIndexSeqL] == response[kIndexSeqL] &&
98           (bytes_[kIndexId] == response[kIndexId] || response[kIndexId] == kIdError);
99}
100
101// Implements the Transport interface to work with the fastboot engine.
102class UdpTransport : public Transport {
103  public:
104    // Factory function so we can return nullptr if initialization fails.
105    static std::unique_ptr<UdpTransport> NewTransport(std::unique_ptr<Socket> socket,
106                                                      std::string* error);
107    ~UdpTransport() override = default;
108
109    ssize_t Read(void* data, size_t length) override;
110    ssize_t Write(const void* data, size_t length) override;
111    int Close() override;
112
113  private:
114    UdpTransport(std::unique_ptr<Socket> socket) : socket_(std::move(socket)) {}
115
116    // Performs the UDP initialization procedure. Returns true on success.
117    bool InitializeProtocol(std::string* error);
118
119    // Sends |length| bytes from |data| and waits for the response packet up to |attempts| times.
120    // Continuation packets are handled automatically and any return data is written to |rx_data|.
121    // Excess bytes that cannot fit in |rx_data| are dropped.
122    // On success, returns the number of response data bytes received, which may be greater than
123    // |rx_length|. On failure, returns -1 and fills |error| on failure.
124    ssize_t SendData(Id id, const uint8_t* tx_data, size_t tx_length, uint8_t* rx_data,
125                     size_t rx_length, int attempts, std::string* error);
126
127    // Helper for SendData(); sends a single packet and handles the response. |header| specifies
128    // the initial outgoing packet information but may be modified by this function.
129    ssize_t SendSinglePacketHelper(Header* header, const uint8_t* tx_data, size_t tx_length,
130                                   uint8_t* rx_data, size_t rx_length, int attempts,
131                                   std::string* error);
132
133    std::unique_ptr<Socket> socket_;
134    int sequence_ = -1;
135    size_t max_data_length_ = kMinPacketSize - kHeaderSize;
136    std::vector<uint8_t> rx_packet_;
137
138    DISALLOW_COPY_AND_ASSIGN(UdpTransport);
139};
140
141std::unique_ptr<UdpTransport> UdpTransport::NewTransport(std::unique_ptr<Socket> socket,
142                                                         std::string* error) {
143    std::unique_ptr<UdpTransport> transport(new UdpTransport(std::move(socket)));
144
145    if (!transport->InitializeProtocol(error)) {
146        return nullptr;
147    }
148
149    return transport;
150}
151
152bool UdpTransport::InitializeProtocol(std::string* error) {
153    uint8_t rx_data[4];
154
155    sequence_ = 0;
156    rx_packet_.resize(kMinPacketSize);
157
158    // First send the query packet to sync with the target. Only attempt this a small number of
159    // times so we can fail out quickly if the target isn't available.
160    ssize_t rx_bytes = SendData(kIdDeviceQuery, nullptr, 0, rx_data, sizeof(rx_data),
161                                kMaxConnectAttempts, error);
162    if (rx_bytes == -1) {
163        return false;
164    } else if (rx_bytes < 2) {
165        *error = "invalid query response from target";
166        return false;
167    }
168    // The first two bytes contain the next expected sequence number.
169    sequence_ = ExtractUint16(rx_data);
170
171    // Now send the initialization packet with our version and maximum packet size.
172    uint8_t init_data[] = {kProtocolVersion >> 8, kProtocolVersion & 0xFF,
173                           kHostMaxPacketSize >> 8, kHostMaxPacketSize & 0xFF};
174    rx_bytes = SendData(kIdInitialization, init_data, sizeof(init_data), rx_data, sizeof(rx_data),
175                        kMaxTransmissionAttempts, error);
176    if (rx_bytes == -1) {
177        return false;
178    } else if (rx_bytes < 4) {
179        *error = "invalid initialization response from target";
180        return false;
181    }
182
183    // The first two data bytes contain the version, the second two bytes contain the target max
184    // supported packet size, which must be at least 512 bytes.
185    uint16_t version = ExtractUint16(rx_data);
186    if (version < kProtocolVersion) {
187        *error = android::base::StringPrintf("target reported invalid protocol version %d",
188                                             version);
189        return false;
190    }
191    uint16_t packet_size = ExtractUint16(rx_data + 2);
192    if (packet_size < kMinPacketSize) {
193        *error = android::base::StringPrintf("target reported invalid packet size %d", packet_size);
194        return false;
195    }
196
197    packet_size = std::min(kHostMaxPacketSize, packet_size);
198    max_data_length_ = packet_size - kHeaderSize;
199    rx_packet_.resize(packet_size);
200
201    return true;
202}
203
204// SendData() is just responsible for chunking |data| into packets until it's all been sent.
205// Per-packet timeout/retransmission logic is done in SendSinglePacketHelper().
206ssize_t UdpTransport::SendData(Id id, const uint8_t* tx_data, size_t tx_length, uint8_t* rx_data,
207                               size_t rx_length, int attempts, std::string* error) {
208    if (socket_ == nullptr) {
209        *error = "socket is closed";
210        return -1;
211    }
212
213    Header header;
214    size_t packet_data_length;
215    ssize_t ret = 0;
216    // We often send header-only packets with no data as part of the protocol, so always send at
217    // least once even if |length| == 0, then repeat until we've sent all of |data|.
218    do {
219        // Set the continuation flag and truncate packet data if needed.
220        if (tx_length > max_data_length_) {
221            packet_data_length = max_data_length_;
222            header.Set(id, sequence_, kFlagContinuation);
223        } else {
224            packet_data_length = tx_length;
225            header.Set(id, sequence_, kFlagNone);
226        }
227
228        ssize_t bytes = SendSinglePacketHelper(&header, tx_data, packet_data_length, rx_data,
229                                               rx_length, attempts, error);
230
231        // Advance our read and write buffers for the next packet. Keep going even if we run out
232        // of receive buffer space so we can detect overflows.
233        if (bytes == -1) {
234            return -1;
235        } else if (static_cast<size_t>(bytes) < rx_length) {
236            rx_data += bytes;
237            rx_length -= bytes;
238        } else {
239            rx_data = nullptr;
240            rx_length = 0;
241        }
242
243        tx_length -= packet_data_length;
244        tx_data += packet_data_length;
245
246        ret += bytes;
247    } while (tx_length > 0);
248
249    return ret;
250}
251
252ssize_t UdpTransport::SendSinglePacketHelper(
253        Header* header, const uint8_t* tx_data, size_t tx_length, uint8_t* rx_data,
254        size_t rx_length, const int attempts, std::string* error) {
255    ssize_t total_data_bytes = 0;
256    error->clear();
257
258    int attempts_left = attempts;
259    while (attempts_left > 0) {
260        if (!socket_->Send({{header->bytes(), kHeaderSize}, {tx_data, tx_length}})) {
261            *error = Socket::GetErrorMessage();
262            return -1;
263        }
264
265        // Keep receiving until we get a matching response or we timeout.
266        ssize_t bytes = 0;
267        do {
268            bytes = socket_->Receive(rx_packet_.data(), rx_packet_.size(), kResponseTimeoutMs);
269            if (bytes == -1) {
270                if (socket_->ReceiveTimedOut()) {
271                    break;
272                }
273                *error = Socket::GetErrorMessage();
274                return -1;
275            } else if (bytes < static_cast<ssize_t>(kHeaderSize)) {
276                *error = "protocol error: incomplete header";
277                return -1;
278            }
279        } while (!header->Matches(rx_packet_.data()));
280
281        if (socket_->ReceiveTimedOut()) {
282            --attempts_left;
283            continue;
284        }
285        ++sequence_;
286
287        // Save to |error| or |rx_data| as appropriate.
288        if (rx_packet_[kIndexId] == kIdError) {
289            error->append(rx_packet_.data() + kHeaderSize, rx_packet_.data() + bytes);
290        } else {
291            total_data_bytes += bytes - kHeaderSize;
292            size_t rx_data_bytes = std::min<size_t>(bytes - kHeaderSize, rx_length);
293            if (rx_data_bytes > 0) {
294                memcpy(rx_data, rx_packet_.data() + kHeaderSize, rx_data_bytes);
295                rx_data += rx_data_bytes;
296                rx_length -= rx_data_bytes;
297            }
298        }
299
300        // If the response has a continuation flag we need to prompt for more data by sending
301        // an empty packet.
302        if (rx_packet_[kIndexFlags] & kFlagContinuation) {
303            // We got a valid response so reset our attempt counter.
304            attempts_left = attempts;
305            header->Set(rx_packet_[kIndexId], sequence_, kFlagNone);
306            tx_data = nullptr;
307            tx_length = 0;
308            continue;
309        }
310
311        break;
312    }
313
314    if (attempts_left <= 0) {
315        *error = "no response from target";
316        return -1;
317    }
318
319    if (rx_packet_[kIndexId] == kIdError) {
320        *error = "target reported error: " + *error;
321        return -1;
322    }
323
324    return total_data_bytes;
325}
326
327ssize_t UdpTransport::Read(void* data, size_t length) {
328    // Read from the target by sending an empty packet.
329    std::string error;
330    ssize_t bytes = SendData(kIdFastboot, nullptr, 0, reinterpret_cast<uint8_t*>(data), length,
331                             kMaxTransmissionAttempts, &error);
332
333    if (bytes == -1) {
334        fprintf(stderr, "UDP error: %s\n", error.c_str());
335        return -1;
336    } else if (static_cast<size_t>(bytes) > length) {
337        // Fastboot protocol error: the target sent more data than our fastboot engine was prepared
338        // to receive.
339        fprintf(stderr, "UDP error: receive overflow, target sent too much fastboot data\n");
340        return -1;
341    }
342
343    return bytes;
344}
345
346ssize_t UdpTransport::Write(const void* data, size_t length) {
347    std::string error;
348    ssize_t bytes = SendData(kIdFastboot, reinterpret_cast<const uint8_t*>(data), length, nullptr,
349                             0, kMaxTransmissionAttempts, &error);
350
351    if (bytes == -1) {
352        fprintf(stderr, "UDP error: %s\n", error.c_str());
353        return -1;
354    } else if (bytes > 0) {
355        // UDP protocol error: only empty ACK packets are allowed when writing to a device.
356        fprintf(stderr, "UDP error: target sent fastboot data out-of-turn\n");
357        return -1;
358    }
359
360    return length;
361}
362
363int UdpTransport::Close() {
364    if (socket_ == nullptr) {
365        return 0;
366    }
367
368    int result = socket_->Close();
369    socket_.reset();
370    return result;
371}
372
373std::unique_ptr<Transport> Connect(const std::string& hostname, int port, std::string* error) {
374    return internal::Connect(Socket::NewClient(Socket::Protocol::kUdp, hostname, port, error),
375                             error);
376}
377
378namespace internal {
379
380std::unique_ptr<Transport> Connect(std::unique_ptr<Socket> sock, std::string* error) {
381    if (sock == nullptr) {
382        // If Socket creation failed |error| is already set.
383        return nullptr;
384    }
385
386    return UdpTransport::NewTransport(std::move(sock), error);
387}
388
389}  // namespace internal
390
391}  // namespace udp
392