1//
2// Copyright (C) 2012 The Android Open Source Project
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16
17#include "shill/arp_client.h"
18
19#include <linux/if_packet.h>
20#include <net/ethernet.h>
21#include <net/if_arp.h>
22#include <netinet/in.h>
23#include <string.h>
24
25#include "shill/arp_packet.h"
26#include "shill/logging.h"
27#include "shill/net/byte_string.h"
28#include "shill/net/sockets.h"
29
30namespace shill {
31
32// ARP opcode is the last uint16_t in the ARP header.
33const size_t ArpClient::kArpOpOffset = sizeof(arphdr) - sizeof(uint16_t);
34
35// The largest packet we expect is one with IPv6 addresses in it.
36const size_t ArpClient::kMaxArpPacketLength =
37    sizeof(arphdr) + sizeof(in6_addr) * 2 + ETH_ALEN * 2;
38
39ArpClient::ArpClient(int interface_index)
40    : interface_index_(interface_index),
41      sockets_(new Sockets()),
42      socket_(-1) {}
43
44ArpClient::~ArpClient() {}
45
46bool ArpClient::StartReplyListener() {
47  return Start(ARPOP_REPLY);
48}
49
50bool ArpClient::StartRequestListener() {
51  return Start(ARPOP_REQUEST);
52}
53
54bool ArpClient::Start(uint16_t arp_opcode) {
55  if (!CreateSocket(arp_opcode)) {
56    LOG(ERROR) << "Could not open ARP socket.";
57    Stop();
58    return false;
59  }
60  return true;
61}
62
63void ArpClient::Stop() {
64  socket_closer_.reset();
65}
66
67
68bool ArpClient::CreateSocket(uint16_t arp_opcode) {
69  int socket = sockets_->Socket(PF_PACKET, SOCK_DGRAM, htons(ETHERTYPE_ARP));
70  if (socket == -1) {
71    PLOG(ERROR) << "Could not create ARP socket";
72    return false;
73  }
74  socket_ = socket;
75  socket_closer_.reset(new ScopedSocketCloser(sockets_.get(), socket_));
76
77  // Create a packet filter incoming ARP packets.
78  const sock_filter arp_filter[] = {
79    // If a packet contains the ARP opcode we are looking for...
80    BPF_STMT(BPF_LD | BPF_H | BPF_ABS, kArpOpOffset),
81    BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, arp_opcode, 0, 1),
82    // Return the the packet (up to largest expected packet size).
83    BPF_STMT(BPF_RET | BPF_K, kMaxArpPacketLength),
84    // Otherwise, drop it.
85    BPF_STMT(BPF_RET | BPF_K, 0),
86  };
87
88  sock_fprog pf;
89  pf.filter = const_cast<sock_filter*>(arp_filter);
90  pf.len = arraysize(arp_filter);
91  if (sockets_->AttachFilter(socket_, &pf) != 0) {
92    PLOG(ERROR) << "Could not attach packet filter";
93    return false;
94  }
95
96  if (sockets_->SetNonBlocking(socket_) != 0) {
97    PLOG(ERROR) << "Could not set socket to be non-blocking";
98    return false;
99  }
100
101  sockaddr_ll socket_address;
102  memset(&socket_address, 0, sizeof(socket_address));
103  socket_address.sll_family = AF_PACKET;
104  socket_address.sll_protocol = htons(ETHERTYPE_ARP);
105  socket_address.sll_ifindex = interface_index_;
106
107  if (sockets_->Bind(socket_,
108                     reinterpret_cast<struct sockaddr*>(&socket_address),
109                     sizeof(socket_address)) != 0) {
110    PLOG(ERROR) << "Could not bind socket to interface";
111    return false;
112  }
113
114  return true;
115}
116
117bool ArpClient::ReceivePacket(ArpPacket* packet, ByteString* sender) const {
118  ByteString payload(kMaxArpPacketLength);
119  sockaddr_ll socket_address;
120  memset(&socket_address, 0, sizeof(socket_address));
121  socklen_t socklen = sizeof(socket_address);
122  int result = sockets_->RecvFrom(
123      socket_,
124      payload.GetData(),
125      payload.GetLength(),
126      0,
127      reinterpret_cast<struct sockaddr*>(&socket_address),
128      &socklen);
129  if (result < 0) {
130    PLOG(ERROR) << "Socket recvfrom failed";
131    return false;
132  }
133
134  payload.Resize(result);
135  if (!packet->Parse(payload)) {
136    LOG(ERROR) << "Failed to parse ARP packet.";
137    return false;
138  }
139
140  // The socket address returned may only be big enough to contain
141  // the hardware address of the sender.
142  CHECK(socklen >=
143        sizeof(socket_address) - sizeof(socket_address.sll_addr) + ETH_ALEN);
144  CHECK(socket_address.sll_halen == ETH_ALEN);
145  *sender = ByteString(
146      reinterpret_cast<const unsigned char*>(&socket_address.sll_addr),
147      socket_address.sll_halen);
148  return true;
149}
150
151bool ArpClient::TransmitRequest(const ArpPacket& packet) const {
152  ByteString payload;
153  if (!packet.FormatRequest(&payload)) {
154    return false;
155  }
156
157  sockaddr_ll socket_address;
158  memset(&socket_address, 0, sizeof(socket_address));
159  socket_address.sll_family = AF_PACKET;
160  socket_address.sll_protocol = htons(ETHERTYPE_ARP);
161  socket_address.sll_hatype = ARPHRD_ETHER;
162  socket_address.sll_halen = ETH_ALEN;
163  socket_address.sll_ifindex = interface_index_;
164
165  ByteString remote_address = packet.remote_mac_address();
166  CHECK(sizeof(socket_address.sll_addr) >= remote_address.GetLength());
167  if (remote_address.IsZero()) {
168    // If the destination MAC address is unspecified, send the packet
169    // to the broadcast (all-ones) address.
170    remote_address.BitwiseInvert();
171  }
172  memcpy(&socket_address.sll_addr, remote_address.GetConstData(),
173         remote_address.GetLength());
174
175  int result = sockets_->SendTo(
176      socket_,
177      payload.GetConstData(),
178      payload.GetLength(),
179      0,
180      reinterpret_cast<struct sockaddr*>(&socket_address),
181      sizeof(socket_address));
182  const int expected_result  = static_cast<int>(payload.GetLength());
183  if (result != expected_result) {
184    if (result < 0) {
185      PLOG(ERROR) << "Socket sendto failed";
186    } else if (result < static_cast<int>(payload.GetLength())) {
187      LOG(ERROR) << "Socket sendto returned "
188                 << result
189                 << " which is different from expected result "
190                 << expected_result;
191    }
192    return false;
193  }
194
195  return true;
196}
197
198}  // namespace shill
199