1/*
2 * Copyright 2008, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <stdlib.h>
18#include <string.h>
19#include <unistd.h>
20#include <sys/uio.h>
21#include <sys/socket.h>
22#include <netinet/in.h>
23#include <netinet/ip.h>
24#include <netinet/udp.h>
25#include <linux/if_packet.h>
26#include <linux/if_ether.h>
27#include <errno.h>
28
29#ifdef ANDROID
30#define LOG_TAG "DHCP"
31#include <cutils/log.h>
32#else
33#include <stdio.h>
34#include <string.h>
35#define ALOGD printf
36#define ALOGW printf
37#endif
38
39#include "dhcpmsg.h"
40
/*
 * Reports an error described by `message` and returns a value that
 * open_raw_socket() hands straight back to its caller as its failure result.
 * Defined in another translation unit; presumably it logs `message` together
 * with strerror(errno) and returns -1 -- confirm against the definition.
 * Declared with a full prototype: the old-style `int fatal();` disables
 * argument checking, and under C23 it means "takes no arguments".
 */
int fatal(const char *message);

/*
 * Open a PF_PACKET / SOCK_DGRAM socket for IPv4 frames (ETH_P_IP) and bind
 * it to the interface with index `if_index`.
 *
 * ifname:   unused.
 * hwaddr:   at least ETH_ALEN bytes; copied into the bind address.
 * if_index: index of the interface to bind to.
 *
 * Returns the socket descriptor on success, or the value returned by
 * fatal() if socket() or bind() fails. On bind() failure the socket is
 * closed first so the descriptor is not leaked.
 */
int open_raw_socket(const char *ifname __attribute__((unused)), uint8_t *hwaddr, int if_index)
{
    int s;
    struct sockaddr_ll bindaddr;

    if ((s = socket(PF_PACKET, SOCK_DGRAM, htons(ETH_P_IP))) < 0) {
        return fatal("socket(PF_PACKET)");
    }

    memset(&bindaddr, 0, sizeof(bindaddr));
    bindaddr.sll_family = AF_PACKET;
    bindaddr.sll_protocol = htons(ETH_P_IP);
    bindaddr.sll_halen = ETH_ALEN;
    memcpy(bindaddr.sll_addr, hwaddr, ETH_ALEN);
    bindaddr.sll_ifindex = if_index;

    if (bind(s, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) < 0) {
        /*
         * Don't leak the descriptor. close() may overwrite errno, so keep
         * bind()'s error in case fatal() reports errno.
         */
        int saved_errno = errno;
        close(s);
        errno = saved_errno;
        return fatal("Cannot bind raw socket to interface");
    }

    return s;
}
65
/*
 * Accumulate the 16-bit ones'-complement sum (RFC 1071) of `count` bytes
 * starting at `buffer`, continuing from `startsum`, and fold carries back
 * into the low 16 bits. The result is NOT complemented; finish_sum() does
 * that.
 *
 * Words are loaded in host byte order. The ones'-complement sum is
 * byte-order independent, so the finished value can be stored into a
 * header field without swapping. A trailing odd byte is treated as if it
 * were followed by a zero byte, as the Internet checksum requires.
 */
static uint32_t checksum(const void *buffer, unsigned int count, uint32_t startsum)
{
    const uint8_t *bytes = (const uint8_t *)buffer;
    uint32_t sum = startsum;
    uint32_t upper16;
    uint16_t word;

    /*
     * Load each 16-bit word via memcpy rather than through a casted
     * uint16_t pointer. Callers pass structs and scalar fields, so a cast
     * would violate strict aliasing and could be misaligned.
     */
    while (count > 1) {
        memcpy(&word, bytes, sizeof(word));
        sum += word;
        bytes += 2;
        count -= 2;
    }
    if (count > 0) {
        /*
         * Pad the last byte with a zero byte that follows it in memory.
         * This is correct on both little- and big-endian hosts, unlike
         * adding the byte's numeric value.
         */
        word = 0;
        memcpy(&word, bytes, 1);
        sum += word;
    }
    while ((upper16 = (sum >> 16)) != 0) {
        sum = (sum & 0xffff) + upper16;
    }
    return sum;
}
84
/*
 * Turn a folded ones'-complement sum into the final Internet checksum
 * value: invert every bit, then keep only the low 16 bits.
 */
static uint32_t finish_sum(uint32_t sum)
{
    const uint32_t inverted = sum ^ 0xffffffffu;

    return inverted & 0xffffu;
}
89
/*
 * Wrap the `size`-byte DHCP message `msg` in hand-built IPv4 and UDP
 * headers and send it on the raw packet socket `s`. The frame goes out of
 * interface `if_index` to the Ethernet broadcast address.
 *
 * saddr/daddr are copied into the IP header unchanged (no byte swapping
 * here). Callers presumably pass them in network byte order -- verify
 * against callers. sport/dport are host-order port numbers and are
 * converted with htons().
 *
 * Returns the result of sendmsg(): the number of bytes sent, or -1 with
 * errno set. Also returns -1 with errno = EINVAL if `size` is negative or
 * too large for the 16-bit IP/UDP length fields.
 */
int send_packet(int s, int if_index, struct dhcp_msg *msg, int size,
                uint32_t saddr, uint32_t daddr, uint32_t sport, uint32_t dport)
{
    struct iphdr ip;
    struct udphdr udp;
    struct iovec iov[3];
    uint32_t udpsum;
    uint16_t temp;
    struct msghdr msghdr;
    struct sockaddr_ll destaddr;

    /*
     * Reject sizes that would become a huge unsigned byte count in
     * checksum(), or that would overflow the IP total length field.
     */
    if (size < 0 || (size_t)size > 0xffff - sizeof(ip) - sizeof(udp)) {
        errno = EINVAL;
        return -1;
    }

    ip.version = IPVERSION;
    ip.ihl = sizeof(ip) >> 2;
    ip.tos = 0;
    ip.tot_len = htons(sizeof(ip) + sizeof(udp) + size);
    ip.id = 0;
    ip.frag_off = 0;
    ip.ttl = IPDEFTTL;
    ip.protocol = IPPROTO_UDP;
    ip.check = 0;
    ip.saddr = saddr;
    ip.daddr = daddr;
    ip.check = finish_sum(checksum(&ip, sizeof(ip), 0));

    udp.source = htons(sport);
    udp.dest = htons(dport);
    udp.len = htons(sizeof(udp) + size);
    udp.check = 0;

    /* Calculate checksum for pseudo header: saddr, daddr, zero+protocol, UDP length */
    udpsum = checksum(&ip.saddr, sizeof(ip.saddr), 0);
    udpsum = checksum(&ip.daddr, sizeof(ip.daddr), udpsum);
    temp = htons(IPPROTO_UDP);
    udpsum = checksum(&temp, sizeof(temp), udpsum);
    temp = udp.len;
    udpsum = checksum(&temp, sizeof(temp), udpsum);

    /* Add in the checksum for the udp header (check field still 0) */
    udpsum = checksum(&udp, sizeof(udp), udpsum);

    /* Add in the checksum for the data */
    udpsum = checksum(msg, size, udpsum);
    udp.check = finish_sum(udpsum);
    /*
     * RFC 768: a transmitted UDP checksum of 0 means "no checksum", so a
     * computed value of 0 must be sent as all ones.
     */
    if (udp.check == 0) {
        udp.check = 0xffff;
    }

    iov[0].iov_base = (char *)&ip;
    iov[0].iov_len = sizeof(ip);
    iov[1].iov_base = (char *)&udp;
    iov[1].iov_len = sizeof(udp);
    iov[2].iov_base = (char *)msg;
    iov[2].iov_len = size;

    memset(&destaddr, 0, sizeof(destaddr));
    destaddr.sll_family = AF_PACKET;
    destaddr.sll_protocol = htons(ETH_P_IP);
    destaddr.sll_ifindex = if_index;
    destaddr.sll_halen = ETH_ALEN;
    memcpy(destaddr.sll_addr, "\xff\xff\xff\xff\xff\xff", ETH_ALEN);

    /*
     * Zero the whole struct first. Some libcs' struct msghdr has padding
     * members beyond the ones assigned below.
     */
    memset(&msghdr, 0, sizeof(msghdr));
    msghdr.msg_name = &destaddr;
    msghdr.msg_namelen = sizeof(destaddr);
    msghdr.msg_iov = iov;
    msghdr.msg_iovlen = sizeof(iov) / sizeof(struct iovec);
    msghdr.msg_flags = 0;
    msghdr.msg_control = 0;
    msghdr.msg_controllen = 0;
    return sendmsg(s, &msghdr, 0);
}
156
157int receive_packet(int s, struct dhcp_msg *msg)
158{
159    int nread;
160    int is_valid;
161    struct dhcp_packet {
162        struct iphdr ip;
163        struct udphdr udp;
164        struct dhcp_msg dhcp;
165    } packet;
166    int dhcp_size;
167    uint32_t sum;
168    uint16_t temp;
169    uint32_t saddr, daddr;
170
171    nread = read(s, &packet, sizeof(packet));
172    if (nread < 0) {
173        return -1;
174    }
175    /*
176     * The raw packet interface gives us all packets received by the
177     * network interface. We need to filter out all packets that are
178     * not meant for us.
179     */
180    is_valid = 0;
181    if (nread < (int)(sizeof(struct iphdr) + sizeof(struct udphdr))) {
182#if VERBOSE
183        ALOGD("Packet is too small (%d) to be a UDP datagram", nread);
184#endif
185    } else if (packet.ip.version != IPVERSION || packet.ip.ihl != (sizeof(packet.ip) >> 2)) {
186#if VERBOSE
187        ALOGD("Not a valid IP packet");
188#endif
189    } else if (nread < ntohs(packet.ip.tot_len)) {
190#if VERBOSE
191        ALOGD("Packet was truncated (read %d, needed %d)", nread, ntohs(packet.ip.tot_len));
192#endif
193    } else if (packet.ip.protocol != IPPROTO_UDP) {
194#if VERBOSE
195        ALOGD("IP protocol (%d) is not UDP", packet.ip.protocol);
196#endif
197    } else if (packet.udp.dest != htons(PORT_BOOTP_CLIENT)) {
198#if VERBOSE
199        ALOGD("UDP dest port (%d) is not DHCP client", ntohs(packet.udp.dest));
200#endif
201    } else {
202        is_valid = 1;
203    }
204
205    if (!is_valid) {
206        return -1;
207    }
208
209    /* Seems like it's probably a valid DHCP packet */
210    /* validate IP header checksum */
211    sum = finish_sum(checksum(&packet.ip, sizeof(packet.ip), 0));
212    if (sum != 0) {
213        ALOGW("IP header checksum failure (0x%x)", packet.ip.check);
214        return -1;
215    }
216    /*
217     * Validate the UDP checksum.
218     * Since we don't need the IP header anymore, we "borrow" it
219     * to construct the pseudo header used in the checksum calculation.
220     */
221    dhcp_size = ntohs(packet.udp.len) - sizeof(packet.udp);
222    saddr = packet.ip.saddr;
223    daddr = packet.ip.daddr;
224    nread = ntohs(packet.ip.tot_len);
225    memset(&packet.ip, 0, sizeof(packet.ip));
226    packet.ip.saddr = saddr;
227    packet.ip.daddr = daddr;
228    packet.ip.protocol = IPPROTO_UDP;
229    packet.ip.tot_len = packet.udp.len;
230    temp = packet.udp.check;
231    packet.udp.check = 0;
232    sum = finish_sum(checksum(&packet, nread, 0));
233    packet.udp.check = temp;
234    if (!sum)
235        sum = finish_sum(sum);
236    if (temp != sum) {
237        ALOGW("UDP header checksum failure (0x%x should be 0x%x)", sum, temp);
238        return -1;
239    }
240    memcpy(msg, &packet.dhcp, dhcp_size);
241    return dhcp_size;
242}
243