1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.net.util;
18
19import java.net.Inet6Address;
20import java.net.InetAddress;
21import java.nio.BufferOverflowException;
22import java.nio.BufferUnderflowException;
23import java.nio.ByteBuffer;
24import java.nio.ShortBuffer;
25
26import static android.system.OsConstants.IPPROTO_TCP;
27import static android.system.OsConstants.IPPROTO_UDP;
28
29/**
30 * @hide
31 */
32public class IpUtils {
33    /**
34     * Converts a signed short value to an unsigned int value.  Needed
35     * because Java does not have unsigned types.
36     */
37    private static int intAbs(short v) {
38        return v & 0xFFFF;
39    }
40
41    /**
42     * Performs an IP checksum (used in IP header and across UDP
43     * payload) on the specified portion of a ByteBuffer.  The seed
44     * allows the checksum to commence with a specified value.
45     */
46    private static int checksum(ByteBuffer buf, int seed, int start, int end) {
47        int sum = seed;
48        final int bufPosition = buf.position();
49
50        // set position of original ByteBuffer, so that the ShortBuffer
51        // will be correctly initialized
52        buf.position(start);
53        ShortBuffer shortBuf = buf.asShortBuffer();
54
55        // re-set ByteBuffer position
56        buf.position(bufPosition);
57
58        final int numShorts = (end - start) / 2;
59        for (int i = 0; i < numShorts; i++) {
60            sum += intAbs(shortBuf.get(i));
61        }
62        start += numShorts * 2;
63
64        // see if a singleton byte remains
65        if (end != start) {
66            short b = buf.get(start);
67
68            // make it unsigned
69            if (b < 0) {
70                b += 256;
71            }
72
73            sum += b * 256;
74        }
75
76        sum = ((sum >> 16) & 0xFFFF) + (sum & 0xFFFF);
77        sum = ((sum + ((sum >> 16) & 0xFFFF)) & 0xFFFF);
78        int negated = ~sum;
79        return intAbs((short) negated);
80    }
81
82    private static int pseudoChecksumIPv4(
83            ByteBuffer buf, int headerOffset, int protocol, int transportLen) {
84        int partial = protocol + transportLen;
85        partial += intAbs(buf.getShort(headerOffset + 12));
86        partial += intAbs(buf.getShort(headerOffset + 14));
87        partial += intAbs(buf.getShort(headerOffset + 16));
88        partial += intAbs(buf.getShort(headerOffset + 18));
89        return partial;
90    }
91
92    private static int pseudoChecksumIPv6(
93            ByteBuffer buf, int headerOffset, int protocol, int transportLen) {
94        int partial = protocol + transportLen;
95        for (int offset = 8; offset < 40; offset += 2) {
96            partial += intAbs(buf.getShort(headerOffset + offset));
97        }
98        return partial;
99    }
100
101    private static byte ipversion(ByteBuffer buf, int headerOffset) {
102        return (byte) ((buf.get(headerOffset) & (byte) 0xf0) >> 4);
103   }
104
105    public static short ipChecksum(ByteBuffer buf, int headerOffset) {
106        byte ihl = (byte) (buf.get(headerOffset) & 0x0f);
107        return (short) checksum(buf, 0, headerOffset, headerOffset + ihl * 4);
108    }
109
110    private static short transportChecksum(ByteBuffer buf, int protocol,
111            int ipOffset, int transportOffset, int transportLen) {
112        if (transportLen < 0) {
113            throw new IllegalArgumentException("Transport length < 0: " + transportLen);
114        }
115        int sum;
116        byte ver = ipversion(buf, ipOffset);
117        if (ver == 4) {
118            sum = pseudoChecksumIPv4(buf, ipOffset, protocol, transportLen);
119        } else if (ver == 6) {
120            sum = pseudoChecksumIPv6(buf, ipOffset, protocol, transportLen);
121        } else {
122            throw new UnsupportedOperationException("Checksum must be IPv4 or IPv6");
123        }
124
125        sum = checksum(buf, sum, transportOffset, transportOffset + transportLen);
126        if (protocol == IPPROTO_UDP && sum == 0) {
127            sum = (short) 0xffff;
128        }
129        return (short) sum;
130    }
131
132    public static short udpChecksum(ByteBuffer buf, int ipOffset, int transportOffset) {
133        int transportLen = intAbs(buf.getShort(transportOffset + 4));
134        return transportChecksum(buf, IPPROTO_UDP, ipOffset, transportOffset, transportLen);
135    }
136
137    public static short tcpChecksum(ByteBuffer buf, int ipOffset, int transportOffset,
138            int transportLen) {
139        return transportChecksum(buf, IPPROTO_TCP, ipOffset, transportOffset, transportLen);
140    }
141
142    public static String addressAndPortToString(InetAddress address, int port) {
143        return String.format(
144                (address instanceof Inet6Address) ? "[%s]:%d" : "%s:%d",
145                address.getHostAddress(), port);
146    }
147
148    public static boolean isValidUdpOrTcpPort(int port) {
149        return port > 0 && port < 65536;
150    }
151}
152