1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 *  * Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 *  * Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in
12 *    the documentation and/or other materials provided with the
13 *    distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
18 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
19 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
20 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
21 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
22 * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
23 * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
25 * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26 * SUCH DAMAGE.
27 */
28
29#include "tcp.h"
30
31#include <android-base/parseint.h>
32#include <android-base/stringprintf.h>
33
34namespace tcp {
35
// Protocol version this host advertises in its "FB%02d" handshake message.
static constexpr int kProtocolVersion{1};
// Handshake messages are exactly "FB" plus a two-character version, i.e. 4 bytes.
static constexpr size_t kHandshakeLength{4};
// Receive timeout used while waiting for the device's handshake reply (milliseconds, per name).
static constexpr int kHandshakeTimeoutMs{2000};
39
// Decodes the 8-byte big-endian length prefix stored at |buffer| into a host-order 64-bit value.
// |buffer| must point to at least 8 readable bytes.
static uint64_t ExtractMessageLength(const void* buffer) {
    const uint8_t* bytes = static_cast<const uint8_t*>(buffer);
    uint64_t value = 0;
    // Most-significant byte comes first, so shift what we have and append the next byte.
    for (size_t idx = 0; idx < 8; ++idx) {
        value = (value << 8) | bytes[idx];
    }
    return value;
}
48
// Writes |length| into the 8 bytes at |buffer| in big-endian order (most-significant byte first).
// |buffer| must point to at least 8 writable bytes.
static void EncodeMessageLength(uint64_t length, void* buffer) {
    uint8_t* bytes = static_cast<uint8_t*>(buffer);
    // Fill from the last byte backwards, consuming the low-order byte of |length| each step.
    for (int idx = 7; idx >= 0; --idx) {
        bytes[idx] = static_cast<uint8_t>(length & 0xff);
        length >>= 8;
    }
}
55
56class TcpTransport : public Transport {
57  public:
58    // Factory function so we can return nullptr if initialization fails.
59    static std::unique_ptr<TcpTransport> NewTransport(std::unique_ptr<Socket> socket,
60                                                      std::string* error);
61
62    ~TcpTransport() override = default;
63
64    ssize_t Read(void* data, size_t length) override;
65    ssize_t Write(const void* data, size_t length) override;
66    int Close() override;
67
68  private:
69    TcpTransport(std::unique_ptr<Socket> sock) : socket_(std::move(sock)) {}
70
71    // Connects to the device and performs the initial handshake. Returns false and fills |error|
72    // on failure.
73    bool InitializeProtocol(std::string* error);
74
75    std::unique_ptr<Socket> socket_;
76    uint64_t message_bytes_left_ = 0;
77
78    DISALLOW_COPY_AND_ASSIGN(TcpTransport);
79};
80
81std::unique_ptr<TcpTransport> TcpTransport::NewTransport(std::unique_ptr<Socket> socket,
82                                                         std::string* error) {
83    std::unique_ptr<TcpTransport> transport(new TcpTransport(std::move(socket)));
84
85    if (!transport->InitializeProtocol(error)) {
86        return nullptr;
87    }
88
89    return transport;
90}
91
92// These error strings are checked in tcp_test.cpp and should be kept in sync.
93bool TcpTransport::InitializeProtocol(std::string* error) {
94    std::string handshake_message(android::base::StringPrintf("FB%02d", kProtocolVersion));
95
96    if (!socket_->Send(handshake_message.c_str(), kHandshakeLength)) {
97        *error = android::base::StringPrintf("Failed to send initialization message (%s)",
98                                             Socket::GetErrorMessage().c_str());
99        return false;
100    }
101
102    char buffer[kHandshakeLength + 1];
103    buffer[kHandshakeLength] = '\0';
104    if (socket_->ReceiveAll(buffer, kHandshakeLength, kHandshakeTimeoutMs) != kHandshakeLength) {
105        *error = android::base::StringPrintf(
106                "No initialization message received (%s). Target may not support TCP fastboot",
107                Socket::GetErrorMessage().c_str());
108        return false;
109    }
110
111    if (memcmp(buffer, "FB", 2) != 0) {
112        *error = "Unrecognized initialization message. Target may not support TCP fastboot";
113        return false;
114    }
115
116    int version = 0;
117    if (!android::base::ParseInt(buffer + 2, &version) || version < kProtocolVersion) {
118        *error = android::base::StringPrintf("Unknown TCP protocol version %s (host version %02d)",
119                                             buffer + 2, kProtocolVersion);
120        return false;
121    }
122
123    error->clear();
124    return true;
125}
126
127ssize_t TcpTransport::Read(void* data, size_t length) {
128    if (socket_ == nullptr) {
129        return -1;
130    }
131
132    // Unless we're mid-message, read the next 8-byte message length.
133    if (message_bytes_left_ == 0) {
134        char buffer[8];
135        if (socket_->ReceiveAll(buffer, 8, 0) != 8) {
136            Close();
137            return -1;
138        }
139        message_bytes_left_ = ExtractMessageLength(buffer);
140    }
141
142    // Now read the message (up to |length| bytes).
143    if (length > message_bytes_left_) {
144        length = message_bytes_left_;
145    }
146    ssize_t bytes_read = socket_->ReceiveAll(data, length, 0);
147    if (bytes_read == -1) {
148        Close();
149    } else {
150        message_bytes_left_ -= bytes_read;
151    }
152    return bytes_read;
153}
154
155ssize_t TcpTransport::Write(const void* data, size_t length) {
156    if (socket_ == nullptr) {
157        return -1;
158    }
159
160    // Use multi-buffer writes for better performance.
161    char header[8];
162    EncodeMessageLength(length, header);
163    if (!socket_->Send(std::vector<cutils_socket_buffer_t>{{header, 8}, {data, length}})) {
164        Close();
165        return -1;
166    }
167
168    return length;
169}
170
171int TcpTransport::Close() {
172    if (socket_ == nullptr) {
173        return 0;
174    }
175
176    int result = socket_->Close();
177    socket_.reset();
178    return result;
179}
180
181std::unique_ptr<Transport> Connect(const std::string& hostname, int port, std::string* error) {
182    return internal::Connect(Socket::NewClient(Socket::Protocol::kTcp, hostname, port, error),
183                             error);
184}
185
186namespace internal {
187
188std::unique_ptr<Transport> Connect(std::unique_ptr<Socket> sock, std::string* error) {
189    if (sock == nullptr) {
190        // If Socket creation failed |error| is already set.
191        return nullptr;
192    }
193
194    return TcpTransport::NewTransport(std::move(sock), error);
195}
196
197}  // namespace internal
198
199}  // namespace tcp
200