1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/host/gnubby_auth_handler_posix.h"
6
7#include <unistd.h>
8#include <utility>
9
10#include "base/bind.h"
11#include "base/files/file_util.h"
12#include "base/json/json_reader.h"
13#include "base/json/json_writer.h"
14#include "base/lazy_instance.h"
15#include "base/stl_util.h"
16#include "base/values.h"
17#include "net/socket/unix_domain_listen_socket_posix.h"
18#include "remoting/base/logging.h"
19#include "remoting/host/gnubby_socket.h"
20#include "remoting/proto/control.pb.h"
21#include "remoting/protocol/client_stub.h"
22
23namespace remoting {
24
25namespace {
26
// Values of the "type" field of a gnubby-auth JSON message.
const char kControlMessage[] = "control";
const char kDataMessage[] = "data";
const char kErrorMessage[] = "error";

// Keys used inside a gnubby-auth JSON message.
const char kMessageType[] = "type";
const char kConnectionId[] = "connectionId";
const char kControlOption[] = "option";
const char kDataPayload[] = "data";

// Extension message type used for host->client messages, and the only
// control option that enables the authorization socket.
const char kGnubbyAuthMessage[] = "gnubby-auth";
const char kGnubbyAuthV1[] = "auth-v1";
36
// The path of the Unix domain socket to listen for gnubby requests on.
// Written by GnubbyAuthHandler::SetGnubbySocketName() and read when the
// authorization socket is created; an empty path means no socket is opened.
// NOTE(review): being a ::Leaky LazyInstance, this is presumably never
// destroyed at process exit — confirm against base/lazy_instance.h.
base::LazyInstance<base::FilePath>::Leaky g_gnubby_socket_name =
    LAZY_INSTANCE_INITIALIZER;
40
41// STL predicate to match by a StreamListenSocket pointer.
42class CompareSocket {
43 public:
44  explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {}
45
46  bool operator()(const std::pair<int, GnubbySocket*> element) const {
47    return element.second->IsSocket(socket_);
48  }
49
50 private:
51  net::StreamListenSocket* socket_;
52};
53
54// Socket authentication function that only allows connections from callers with
55// the current uid.
56bool MatchUid(const net::UnixDomainServerSocket::Credentials& credentials) {
57  bool allowed = credentials.user_id == getuid();
58  if (!allowed)
59    HOST_LOG << "Refused socket connection from uid " << credentials.user_id;
60  return allowed;
61}
62
// Returns the command code (the first byte of |data|, in the range [0, 255]),
// or -1 if |data| is empty. The byte is widened through unsigned char so that
// values >= 0x80 are not sign-extended on platforms where char is signed.
int GetCommandCode(const std::string& data) {
  return data.empty() ? -1 : static_cast<unsigned char>(data[0]);
}
68
69// Creates a string of byte data from a ListValue of numbers. Returns true if
70// all of the list elements are numbers.
71bool ConvertListValueToString(base::ListValue* bytes, std::string* out) {
72  out->clear();
73
74  unsigned int byte_count = bytes->GetSize();
75  if (byte_count != 0) {
76    out->reserve(byte_count);
77    for (unsigned int i = 0; i < byte_count; i++) {
78      int value;
79      if (!bytes->GetInteger(i, &value))
80        return false;
81      out->push_back(static_cast<char>(value));
82    }
83  }
84  return true;
85}
86
87}  // namespace
88
89GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix(
90    protocol::ClientStub* client_stub)
91    : client_stub_(client_stub), last_connection_id_(0) {
92  DCHECK(client_stub_);
93}
94
95GnubbyAuthHandlerPosix::~GnubbyAuthHandlerPosix() {
96  STLDeleteValues(&active_sockets_);
97}
98
99// static
100scoped_ptr<GnubbyAuthHandler> GnubbyAuthHandler::Create(
101    protocol::ClientStub* client_stub) {
102  return scoped_ptr<GnubbyAuthHandler>(new GnubbyAuthHandlerPosix(client_stub));
103}
104
105// static
106void GnubbyAuthHandler::SetGnubbySocketName(
107    const base::FilePath& gnubby_socket_name) {
108  g_gnubby_socket_name.Get() = gnubby_socket_name;
109}
110
111void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) {
112  DCHECK(CalledOnValidThread());
113
114  scoped_ptr<base::Value> value(base::JSONReader::Read(message));
115  base::DictionaryValue* client_message;
116  if (value && value->GetAsDictionary(&client_message)) {
117    std::string type;
118    if (!client_message->GetString(kMessageType, &type)) {
119      LOG(ERROR) << "Invalid gnubby-auth message";
120      return;
121    }
122
123    if (type == kControlMessage) {
124      std::string option;
125      if (client_message->GetString(kControlOption, &option) &&
126          option == kGnubbyAuthV1) {
127        CreateAuthorizationSocket();
128      } else {
129        LOG(ERROR) << "Invalid gnubby-auth control option";
130      }
131    } else if (type == kDataMessage) {
132      ActiveSockets::iterator iter = GetSocketForMessage(client_message);
133      if (iter != active_sockets_.end()) {
134        base::ListValue* bytes;
135        std::string response;
136        if (client_message->GetList(kDataPayload, &bytes) &&
137            ConvertListValueToString(bytes, &response)) {
138          HOST_LOG << "Sending gnubby response: " << GetCommandCode(response);
139          iter->second->SendResponse(response);
140        } else {
141          LOG(ERROR) << "Invalid gnubby data";
142          SendErrorAndCloseActiveSocket(iter);
143        }
144      } else {
145        LOG(ERROR) << "Unknown gnubby-auth data connection";
146      }
147    } else if (type == kErrorMessage) {
148      ActiveSockets::iterator iter = GetSocketForMessage(client_message);
149      if (iter != active_sockets_.end()) {
150        HOST_LOG << "Sending gnubby error";
151        SendErrorAndCloseActiveSocket(iter);
152      } else {
153        LOG(ERROR) << "Unknown gnubby-auth error connection";
154      }
155    } else {
156      LOG(ERROR) << "Unknown gnubby-auth message type: " << type;
157    }
158  }
159}
160
161void GnubbyAuthHandlerPosix::DeliverHostDataMessage(
162    int connection_id,
163    const std::string& data) const {
164  DCHECK(CalledOnValidThread());
165
166  base::DictionaryValue request;
167  request.SetString(kMessageType, kDataMessage);
168  request.SetInteger(kConnectionId, connection_id);
169
170  base::ListValue* bytes = new base::ListValue();
171  for (std::string::const_iterator i = data.begin(); i != data.end(); ++i) {
172    bytes->AppendInteger(static_cast<unsigned char>(*i));
173  }
174  request.Set(kDataPayload, bytes);
175
176  std::string request_json;
177  if (!base::JSONWriter::Write(&request, &request_json)) {
178    LOG(ERROR) << "Failed to create request json";
179    return;
180  }
181
182  protocol::ExtensionMessage message;
183  message.set_type(kGnubbyAuthMessage);
184  message.set_data(request_json);
185
186  client_stub_->DeliverHostMessage(message);
187}
188
189bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting(
190    net::StreamListenSocket* socket) const {
191  return std::find_if(active_sockets_.begin(),
192                      active_sockets_.end(),
193                      CompareSocket(socket)) != active_sockets_.end();
194}
195
196int GnubbyAuthHandlerPosix::GetConnectionIdForTesting(
197    net::StreamListenSocket* socket) const {
198  ActiveSockets::const_iterator iter = std::find_if(
199      active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
200  return iter->first;
201}
202
203GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting(
204    net::StreamListenSocket* socket) const {
205  ActiveSockets::const_iterator iter = std::find_if(
206      active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
207  return iter->second;
208}
209
210void GnubbyAuthHandlerPosix::DidAccept(
211    net::StreamListenSocket* server,
212    scoped_ptr<net::StreamListenSocket> socket) {
213  DCHECK(CalledOnValidThread());
214
215  int connection_id = ++last_connection_id_;
216  active_sockets_[connection_id] =
217      new GnubbySocket(socket.Pass(),
218                       base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut,
219                                  base::Unretained(this),
220                                  connection_id));
221}
222
223void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket,
224                                     const char* data,
225                                     int len) {
226  DCHECK(CalledOnValidThread());
227
228  ActiveSockets::iterator iter = std::find_if(
229      active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
230  if (iter != active_sockets_.end()) {
231    GnubbySocket* gnubby_socket = iter->second;
232    gnubby_socket->AddRequestData(data, len);
233    if (gnubby_socket->IsRequestTooLarge()) {
234      SendErrorAndCloseActiveSocket(iter);
235    } else if (gnubby_socket->IsRequestComplete()) {
236      std::string request_data;
237      gnubby_socket->GetAndClearRequestData(&request_data);
238      ProcessGnubbyRequest(iter->first, request_data);
239    }
240  } else {
241    LOG(ERROR) << "Received data for unknown connection";
242  }
243}
244
245void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) {
246  DCHECK(CalledOnValidThread());
247
248  ActiveSockets::iterator iter = std::find_if(
249      active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
250  if (iter != active_sockets_.end()) {
251    delete iter->second;
252    active_sockets_.erase(iter);
253  }
254}
255
256void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
257  DCHECK(CalledOnValidThread());
258
259  if (!g_gnubby_socket_name.Get().empty()) {
260    // If the file already exists, a socket in use error is returned.
261    base::DeleteFile(g_gnubby_socket_name.Get(), false);
262
263    HOST_LOG << "Listening for gnubby requests on "
264             << g_gnubby_socket_name.Get().value();
265
266    auth_socket_ = net::deprecated::UnixDomainListenSocket::CreateAndListen(
267        g_gnubby_socket_name.Get().value(), this, base::Bind(MatchUid));
268    if (!auth_socket_.get()) {
269      LOG(ERROR) << "Failed to open socket for gnubby requests";
270    }
271  } else {
272    HOST_LOG << "No gnubby socket name specified";
273  }
274}
275
276void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(
277    int connection_id,
278    const std::string& request_data) {
279  HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data);
280  DeliverHostDataMessage(connection_id, request_data);
281}
282
283GnubbyAuthHandlerPosix::ActiveSockets::iterator
284GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue* message) {
285  int connection_id;
286  if (message->GetInteger(kConnectionId, &connection_id)) {
287    return active_sockets_.find(connection_id);
288  }
289  return active_sockets_.end();
290}
291
292void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket(
293    const ActiveSockets::iterator& iter) {
294  iter->second->SendSshError();
295
296  delete iter->second;
297  active_sockets_.erase(iter);
298}
299
300void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) {
301  HOST_LOG << "Gnubby request timed out";
302  ActiveSockets::iterator iter = active_sockets_.find(connection_id);
303  if (iter != active_sockets_.end())
304    SendErrorAndCloseActiveSocket(iter);
305}
306
307}  // namespace remoting
308