1//
2// Copyright (C) 2015 The Android Open Source Project
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16
#include "dhcp_client/dhcpv4.h"

#include <linux/filter.h>
#include <linux/if_packet.h>
#include <net/ethernet.h>
#include <net/if.h>
#include <net/if_arp.h>
#include <netinet/ip.h>
#include <netinet/udp.h>

#include <random>
#include <vector>

#include <base/bind.h>
#include <base/logging.h>

#include "dhcp_client/dhcp_message.h"
33
34using base::Bind;
35using base::Unretained;
36using shill::ByteString;
37using shill::IOHandlerFactoryContainer;
38
39namespace dhcp_client {
40
namespace {

// Well-known UDP ports used by DHCP (RFC 2131, section 4.1).
const uint16_t kDHCPServerPort = 67;
const uint16_t kDHCPClientPort = 68;

// Sentinel meaning "no socket is open".
const int kInvalidSocketDescriptor = -1;

// Bounds on the IPv4 header length (RFC 791): at least 20 octets, at most 60.
const size_t kIPHeaderMinLength = 20;
const size_t kIPHeaderMaxLength = 60;

// Classic BPF program attached to the DHCP packet socket. The socket is a
// SOCK_DGRAM packet socket, so packets arrive without the Ethernet header;
// every offset is written as "offset in the Ethernet frame minus ETH_HLEN".
// Only unfragmented UDP datagrams addressed to the DHCP client port pass.
const sock_filter dhcp_bpf_filter[] = {
  // A <- IP protocol byte.
  BPF_STMT(BPF_LD | BPF_B | BPF_ABS, 23 - ETH_HLEN),
  // Not UDP: jump to the final "drop" instruction.
  BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, IPPROTO_UDP, 0, 6),
  // A <- flags / fragment-offset half-word.
  BPF_STMT(BPF_LD | BPF_H | BPF_ABS, 20 - ETH_HLEN),
  // Any fragment-offset bit set: drop.
  BPF_JUMP(BPF_JMP | BPF_JSET | BPF_K, 0x1fff, 4, 0),
  // X <- IP header length in bytes (4 * IHL).
  BPF_STMT(BPF_LDX | BPF_B | BPF_MSH, 14 - ETH_HLEN),
  // A <- UDP destination port (2 bytes into the UDP header).
  BPF_STMT(BPF_LD | BPF_H | BPF_IND, 16 - ETH_HLEN),
  // Not the DHCP client port: drop.
  BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, kDHCPClientPort, 0, 1),
  // Accept the packet.
  BPF_STMT(BPF_RET | BPF_K, 0x0fffffff),
  // Drop the packet.
  BPF_STMT(BPF_RET | BPF_K, 0),
};
const int dhcp_bpf_filter_len =
    static_cast<int>(sizeof(dhcp_bpf_filter) / sizeof(*dhcp_bpf_filter));

}  // namespace
68
// Creates a DHCPv4 client for a single network interface. The constructor
// only records its arguments and sets up initial state; no socket is opened
// until Start() is called.
//
// |interface_name|   passed to BindToDevice() when the raw socket is created.
// |hardware_address| stored; not used in this file.
// |interface_index|  used as sll_ifindex when binding and when sending.
// |network_id|, |request_hostname|, |arp_gateway|, |unicast_arp|,
// |event_dispatcher| stored for later use; not consumed in this file.
DHCPV4::DHCPV4(const std::string& interface_name,
               const ByteString& hardware_address,
               unsigned int interface_index,
               const std::string& network_id,
               bool request_hostname,
               bool arp_gateway,
               bool unicast_arp,
               EventDispatcherInterface* event_dispatcher)
    : interface_name_(interface_name),
      hardware_address_(hardware_address),
      interface_index_(interface_index),
      network_id_(network_id),
      request_hostname_(request_hostname),
      arp_gateway_(arp_gateway),
      unicast_arp_(unicast_arp),
      event_dispatcher_(event_dispatcher),
      // Factory used by Start() to create the input handler for the socket.
      io_handler_factory_(
          IOHandlerFactoryContainer::GetInstance()->GetIOHandlerFactory()),
      // While in INIT, ParseRawPacket() ignores every server message.
      state_(State::INIT),
      // Source / destination IPv4 addresses written into outgoing packets by
      // MakeRawPacket(): 0.0.0.0 -> 255.255.255.255.
      from_(INADDR_ANY),
      to_(INADDR_BROADCAST),
      socket_(kInvalidSocketDescriptor),
      sockets_(new shill::Sockets()),
      // Generator for the IP identification field in MakeRawPacket().
      // NOTE(review): a time-based seed is predictable; consider
      // std::random_device if these values should be hard to guess.
      random_engine_(time(nullptr)) {
  // NOTE(review): |transaction_id_| is compared in ParseRawPacket() but is
  // not initialized here; confirm it has an in-class initializer.
}
94
95DHCPV4::~DHCPV4() {
96  Stop();
97}
98
99void DHCPV4::ParseRawPacket(shill::InputData* data) {
100  if (data->len < sizeof(iphdr)) {
101    LOG(ERROR) << "Invalid packet length from buffer";
102    return;
103  }
104  // The socket filter has finished part the header validation.
105  // This function will perform the remaining part.
106  int header_len = ValidatePacketHeader(data->buf, data->len);
107  if (header_len == -1) {
108    return;
109  }
110  unsigned char* buffer = data->buf + header_len;
111  DHCPMessage msg;
112  if (!DHCPMessage::InitFromBuffer(buffer, data->len - header_len, &msg)) {
113    LOG(ERROR) << "Failed to initialize DHCP message from buffer";
114    return;
115  }
116  // In INIT state the client ignores all messages from server.
117  if (state_ == State::INIT) {
118    return;
119  }
120  // Check transaction id with the existing one.
121  if (msg.transaction_id() != transaction_id_) {
122    LOG(ERROR) << "Transaction id(xid) doesn't match";
123    return;
124  }
125  uint8_t message_type = msg.message_type();
126  switch (message_type) {
127    case kDHCPMessageTypeOffer:
128      HandleOffer(msg);
129      break;
130    case kDHCPMessageTypeAck:
131      HandleAck(msg);
132      break;
133    case kDHCPMessageTypeNak:
134      HandleNak(msg);
135      break;
136    default:
137      LOG(ERROR) << "Invalid message type: "
138                 << static_cast<int>(message_type);
139  }
140}
141
142void DHCPV4::OnReadError(const std::string& error_msg) {
143  LOG(INFO) << __func__;
144}
145
146bool DHCPV4::Start() {
147  if (!CreateRawSocket()) {
148    return false;
149  }
150
151  input_handler_.reset(io_handler_factory_->CreateIOInputHandler(
152      socket_,
153      Bind(&DHCPV4::ParseRawPacket, Unretained(this)),
154      Bind(&DHCPV4::OnReadError, Unretained(this))));
155  return true;
156}
157
158void DHCPV4::Stop() {
159  input_handler_.reset();
160  if (socket_ != kInvalidSocketDescriptor) {
161    sockets_->Close(socket_);
162  }
163}
164
165bool DHCPV4::CreateRawSocket() {
166  int fd = sockets_->Socket(PF_PACKET,
167                            SOCK_DGRAM | SOCK_CLOEXEC | SOCK_NONBLOCK,
168                            htons(ETHERTYPE_IP));
169  if (fd == kInvalidSocketDescriptor) {
170    PLOG(ERROR) << "Failed to create socket";
171    return false;
172  }
173  shill::ScopedSocketCloser socket_closer(sockets_.get(), fd);
174
175  // Apply the socket filter.
176  sock_fprog pf;
177  memset(&pf, 0, sizeof(pf));
178  pf.filter = const_cast<sock_filter*>(dhcp_bpf_filter);
179  pf.len = dhcp_bpf_filter_len;
180
181  if (sockets_->AttachFilter(fd, &pf) != 0) {
182    PLOG(ERROR) << "Failed to attach filter";
183    return false;
184  }
185
186  if (sockets_->ReuseAddress(fd) == -1) {
187    PLOG(ERROR) << "Failed to reuse socket address";
188    return false;
189  }
190
191  if (sockets_->BindToDevice(fd, interface_name_) < 0) {
192    PLOG(ERROR) << "Failed to bind socket to device";
193    return false;
194  }
195
196  struct sockaddr_ll local;
197  memset(&local, 0, sizeof(local));
198  local.sll_family = PF_PACKET;
199  local.sll_protocol = htons(ETHERTYPE_IP);
200  local.sll_ifindex = static_cast<int>(interface_index_);
201
202  if (sockets_->Bind(fd,
203                     reinterpret_cast<struct sockaddr*>(&local),
204                     sizeof(local)) < 0) {
205    PLOG(ERROR) << "Failed to bind to address";
206    return false;
207  }
208
209  socket_ = socket_closer.Release();
210  return true;
211}
212
213void DHCPV4::HandleOffer(const DHCPMessage& msg) {
214  return;
215}
216
217void DHCPV4::HandleAck(const DHCPMessage& msg) {
218  return;
219}
220
221void DHCPV4::HandleNak(const DHCPMessage& msg) {
222  return;
223}
224
225bool DHCPV4::MakeRawPacket(const DHCPMessage& message, ByteString* output) {
226  ByteString payload;
227  if (!message.Serialize(&payload)) {
228    LOG(ERROR) << "Failed to serialzie dhcp message";
229    return false;
230  }
231  const size_t header_len = sizeof(struct iphdr) + sizeof(struct udphdr);
232  const size_t payload_len = payload.GetLength();
233
234  char buffer[header_len + payload_len];
235  memset(buffer, 0, header_len + payload_len);
236  struct iphdr* ip = reinterpret_cast<struct iphdr*>(buffer);
237  struct udphdr* udp = reinterpret_cast<struct udphdr*>(buffer + sizeof(*ip));
238
239  if (!payload.CopyData(payload_len, buffer + header_len)) {
240    LOG(ERROR) << "Failed to copy data from payload";
241    return false;
242  }
243  udp->uh_sport = htons(kDHCPClientPort);
244  udp->uh_dport = htons(kDHCPServerPort);
245  udp->uh_ulen =
246      htons(static_cast<uint16_t>(sizeof(*udp) + payload.GetLength()));
247
248  // Fill pseudo header (for UDP checksum computing):
249  // Protocol.
250  ip->protocol = IPPROTO_UDP;
251  // Source IP address.
252  ip->saddr = htonl(from_);
253  // Destination IP address.
254  ip->daddr = htonl(to_);
255  // Total length, use udp packet length for pseudo header.
256  ip->tot_len = udp->uh_ulen;
257  // Calculate udp checksum based on:
258  // IPV4 pseudo header, UDP header, and payload.
259  udp->uh_sum = htons(DHCPMessage::ComputeChecksum(
260      reinterpret_cast<const uint8_t*>(buffer),
261      header_len + payload_len));
262
263  // IP version.
264  ip->version = IPVERSION;
265  // IP header length.
266  ip->ihl = sizeof(*ip) >> 2;
267  // Fragment offset field.
268  // The DHCP packet is always smaller than MTU,
269  // so fragmentation is not needed.
270  ip->frag_off = 0;
271  // Identification.
272  ip->id = static_cast<uint16_t>(
273      std::uniform_int_distribution<unsigned int>()(
274          random_engine_) % UINT16_MAX + 1);
275  // Time to live.
276  ip->ttl = IPDEFTTL;
277  // Total length.
278  ip->tot_len = htons(static_cast<uint16_t>(header_len+ payload.GetLength()));
279  // Calculate IP Checksum only based on IP header.
280  ip->check = htons(DHCPMessage::ComputeChecksum(
281      reinterpret_cast<const uint8_t*>(ip),
282      sizeof(*ip)));
283
284  *output = ByteString(buffer, header_len + payload_len);
285  return true;
286}
287
288bool DHCPV4::SendRawPacket(const ByteString& packet) {
289  struct sockaddr_ll remote;
290  memset(&remote, 0, sizeof(remote));
291  remote.sll_family = AF_PACKET;
292  remote.sll_protocol = htons(ETHERTYPE_IP);
293  remote.sll_ifindex = interface_index_;
294  remote.sll_hatype = htons(ARPHRD_ETHER);
295  // Use broadcast hardware address.
296  remote.sll_halen = IFHWADDRLEN;
297  memset(remote.sll_addr, 0xff, IFHWADDRLEN);
298
299  size_t result = sockets_->SendTo(socket_,
300                                   packet.GetConstData(),
301                                   packet.GetLength(),
302                                   0,
303                                   reinterpret_cast<struct sockaddr *>(&remote),
304                                   sizeof(remote));
305
306  if (result != packet.GetLength()) {
307    PLOG(ERROR) << "Socket sento failed";
308    return false;
309  }
310  return true;
311}
312
313int DHCPV4::ValidatePacketHeader(const unsigned char* buffer, size_t len) {
314  const struct iphdr* ip =
315      reinterpret_cast<const struct iphdr*>(buffer);
316  const size_t ip_header_len = static_cast<size_t>(ip->ihl) << 2;
317  if (ip_header_len < kIPHeaderMinLength ||
318      ip_header_len > kIPHeaderMaxLength) {
319    LOG(ERROR) << "Invalid Internet Header Length: "
320               << ip_header_len << " bytes";
321    return -1;
322  }
323  if (ip->tot_len != len) {
324    LOG(ERROR) << "Invalid IP total length";
325    return -1;
326  }
327  // TODO(nywang): Validate other ip header fields.
328
329  const struct udphdr* udp =
330      reinterpret_cast<const struct udphdr*>(buffer + ip_header_len);
331  if (udp->uh_sport != htons(kDHCPServerPort) ||
332      udp->uh_dport != htons(kDHCPClientPort)) {
333    LOG(ERROR) << "Invlaid UDP ports";
334    return -1;
335  }
336  if (udp->uh_ulen != len - ip_header_len) {
337    LOG(ERROR) << "Invalid UDP total length";
338    return -1;
339  }
340  // TODO(nywang): Validate UDP checksum.
341
342  return ip_header_len + sizeof(*udp);
343}
344
345}  // namespace dhcp_client
346
347