1#!/usr/bin/python
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Partial Python implementation of sock_diag functionality."""
18
19# pylint: disable=g-bad-todo
20
import errno
import fcntl
from socket import *  # pylint: disable=wildcard-import
import struct

import cstruct
import net_test
import netlink
28
### Base netlink constants. See include/uapi/linux/netlink.h.
# Netlink protocol number for sock_diag; used as SockDiag.FAMILY.
NETLINK_SOCK_DIAG = 4

### sock_diag constants. See include/uapi/linux/sock_diag.h.
# Message types. SockDiag sends SOCK_DIAG_BY_FAMILY for dumps and lookups,
# and SOCK_DESTROY from CloseSocket.
SOCK_DIAG_BY_FAMILY = 20
SOCK_DESTROY = 21

### inet_diag constants. See include/uapi/linux/inet_diag.h.
# Message types. Not used by SockDiag itself.
TCPDIAG_GETSOCK = 18

# Request attributes.
# Carries filter bytecode (see SockDiag.PackBytecode) on a dump request.
INET_DIAG_REQ_BYTECODE = 1

# Extensions: reply attribute types, several of which SockDiag._Decode
# decodes. To request extension X, set bit (X - 1) in InetDiagReqV2.ext.
INET_DIAG_NONE = 0
INET_DIAG_MEMINFO = 1
INET_DIAG_INFO = 2
INET_DIAG_VEGASINFO = 3
INET_DIAG_CONG = 4
INET_DIAG_TOS = 5
INET_DIAG_TCLASS = 6
INET_DIAG_SKMEMINFO = 7
INET_DIAG_SHUTDOWN = 8
INET_DIAG_DCTCPINFO = 9

# Bytecode operations, as accepted by SockDiag.PackBytecode. Note that their
# numeric values overlap with the extension constants above.
INET_DIAG_BC_NOP = 0
INET_DIAG_BC_JMP = 1
INET_DIAG_BC_S_GE = 2
INET_DIAG_BC_S_LE = 3
INET_DIAG_BC_D_GE = 4
INET_DIAG_BC_D_LE = 5
INET_DIAG_BC_AUTO = 6
INET_DIAG_BC_S_COND = 7
INET_DIAG_BC_D_COND = 8
66
# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
# Socket identity. Ports and addresses are in network byte order ("!");
# addresses are 16 bytes, with IPv4 addresses zero-padded (see
# SockDiag.PaddedAddress).
# NOTE(review): "!" also makes iface big-endian. Confirm against
# include/uapi/linux/inet_diag.h whether idiag_if is host byte order; if so,
# non-zero iface values are byte-swapped on little-endian hosts.
InetDiagSockId = cstruct.Struct(
    "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
# Request header sent with SOCK_DIAG_BY_FAMILY and SOCK_DESTROY. "x" is a pad
# byte; the "S" field (id) is the nested InetDiagSockId.
InetDiagReqV2 = cstruct.Struct(
    "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
    [InetDiagSockId])
# Reply header describing one socket.
InetDiagMsg = cstruct.Struct(
    "InetDiagMsg", "=BBBBSLLLLL",
    "family state timer retrans id expires rqueue wqueue uid inode",
    [InetDiagSockId])
# Payload of the INET_DIAG_MEMINFO attribute.
InetDiagMeminfo = cstruct.Struct(
    "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
# One bytecode instruction: opcode plus yes/no jump offsets in bytes. The
# native-alignment format "BBH" adds no padding for this field layout.
InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
# Argument of INET_DIAG_BC_S_COND / INET_DIAG_BC_D_COND; PackBytecode appends
# the raw address bytes after it.
InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
                                  "family prefix_len port")

# Payload of the INET_DIAG_SKMEMINFO attribute.
SkMeminfo = cstruct.Struct(
    "SkMeminfo", "=IIIIIIII",
    "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
# Payload of the INET_DIAG_INFO attribute for TCP sockets.
TcpInfo = cstruct.Struct(
    "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
    "state ca_state retransmits probes backoff options wscale "
    "rto ato snd_mss rcv_mss "
    "unacked sacked lost retrans fackets "
    "last_data_sent last_ack_sent last_data_recv last_ack_recv "
    "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
    "rcv_rtt rcv_space "
    "total_retrans")  # As of linux 3.13, at least.

# TCP state number of TIME_WAIT. State masks in this module select state N
# with bit (1 << N).
TCP_TIME_WAIT = 6
# State mask selecting every state except TIME_WAIT.
ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)
99
100
class SockDiag(netlink.NetlinkSocket):
  """Netlink socket for the NETLINK_SOCK_DIAG protocol.

  Supports dumping inet (IPv4 and IPv6) sockets, looking up the kernel's
  view of a particular socket, compiling inet_diag filter bytecode, and
  closing sockets with SOCK_DESTROY.
  """

  FAMILY = NETLINK_SOCK_DIAG
  # Debug categories for MaybeDebugCommand; "ALL" or "SOCK" enables output.
  NL_DEBUG = []

  def _Decode(self, command, msg, nla_type, nla_data):
    """Decodes one netlink attribute of a sock_diag reply to a Python type.

    Args:
      command: the netlink message type. Not used by this implementation.
      msg: the parsed message the attribute belongs to; only its family
        field is consulted.
      nla_type: the integer attribute type.
      nla_data: the raw attribute payload.

    Returns:
      A (name, data) tuple. For AF_INET/AF_INET6 messages, name is the
      INET_DIAG_* constant name resolved by _GetConstantName; for any other
      family it is the integer nla_type. data is the decoded payload, or the
      raw payload for attributes that are not recognized.
    """
    if msg.family in (AF_INET, AF_INET6):
      name = self._GetConstantName(__name__, nla_type, "INET_DIAG")
    else:
      # Don't know what this is. Leave it as an integer.
      name = nla_type

    if name in ("INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS"):
      # Single-byte payloads.
      data = ord(nla_data)
    elif name == "INET_DIAG_CONG":
      # Congestion-control name; NUL bytes are stripped.
      data = nla_data.strip("\x00")
    elif name == "INET_DIAG_MEMINFO":
      data = InetDiagMeminfo(nla_data)
    elif name == "INET_DIAG_INFO":
      # TODO: Catch the exception and try something else if it's not TCP.
      data = TcpInfo(nla_data)
    elif name == "INET_DIAG_SKMEMINFO":
      data = SkMeminfo(nla_data)
    else:
      data = nla_data

    return name, data

  def MaybeDebugCommand(self, command, data):
    """Prints a sock_diag request if SOCK debugging is enabled.

    Nothing is printed unless NL_DEBUG contains "ALL" or "SOCK".

    Args:
      command: the netlink message type, e.g. SOCK_DIAG_BY_FAMILY.
      data: the raw request, parsed with _ParseNLMsg as an InetDiagReqV2.
    """
    if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
      return
    # Only resolve the command name when it is actually going to be printed.
    name = self._GetConstantName(__name__, command, "SOCK_")
    parsed = self._ParseNLMsg(data, InetDiagReqV2)
    print("%s %s" % (name, str(parsed)))

  @staticmethod
  def _EmptyInetDiagSockId():
    """Returns an InetDiagSockId whose bytes are all zero."""
    return InetDiagSockId("\x00" * len(InetDiagSockId))

  def PackBytecode(self, instructions):
    """Compiles instructions to inet_diag bytecode.

    The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
    and no are relative jump offsets measured in instructions. The yes branch
    is taken if the instruction matches.

    To accept, jump 1 past the last instruction. To reject, jump 2 past the
    last instruction.

    The target of a no jump is only valid if it is reachable by following
    only yes jumps from the first instruction - see inet_diag_bc_audit and
    valid_cc. This means that if cond1 and cond2 are two mutually exclusive
    filter terms, it is not possible to implement cond1 OR cond2 using:

      ...
      cond1 2 1 arg
      cond2 1 2 arg
      accept
      reject

    but only using:

      ...
      cond1 1 2 arg
      jmp   1 2
      cond2 1 2 arg
      accept
      reject

    The jmp instruction ignores yes and always jumps to no, but yes must be 1
    or the bytecode won't validate. It doesn't have to be jmp - any instruction
    that is guaranteed not to match on real data will do.

    Argument formats by opcode:
      NOP, JMP, AUTO: ignored.
      S_GE, S_LE, D_GE, D_LE: a port number.
      S_COND, D_COND: an (address string, prefix length, port) tuple.

    Args:
      instructions: list of instruction tuples

    Returns:
      A string, the raw bytecode.

    Raises:
      ValueError: if a jump is not positive, a jump lands beyond the reject
        label, or an opcode is not supported.
    """
    args = []
    # positions[i] is the byte offset of instruction i. After the loop, two
    # more entries are present: the accept label and the reject label.
    positions = [0]

    for op, yes, no, arg in instructions:
      if yes <= 0 or no <= 0:
        raise ValueError("Jumps must be > 0")

      if op in (INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO):
        arg = ""
      elif op in (INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                  INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE):
        # Four bytes laid out like an InetDiagBcOp with code=0, yes=0 and
        # no=port.
        arg = "\x00\x00" + struct.pack("=H", arg)
      elif op in (INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND):
        addr, prefixlen, port = arg
        family = AF_INET6 if ":" in addr else AF_INET
        arg = (InetDiagHostcond((family, prefixlen, port)).Pack() +
               inet_pton(family, addr))
      else:
        raise ValueError("Unsupported opcode %d" % op)

      args.append(arg)
      positions.append(positions[-1] + len(InetDiagBcOp) + len(arg))

    # Reject label.
    positions.append(positions[-1] + 4)  # Why 4? Because the kernel uses 4.
    assert len(args) == len(instructions) == len(positions) - 2

    last_label = len(positions) - 1
    packed = []
    for i, (op, yes, no, _) in enumerate(instructions):
      if i + yes > last_label or i + no > last_label:
        raise ValueError("Jump past reject label in instruction %d" % i)
      # Convert instruction-count jumps to byte offsets.
      yes_bytes = positions[i + yes] - positions[i]
      no_bytes = positions[i + no] - positions[i]
      packed.append(InetDiagBcOp((op, yes_bytes, no_bytes)).Pack() + args[i])

    return "".join(packed)

  def Dump(self, diag_req, bytecode=""):
    """Sends a SOCK_DIAG_BY_FAMILY dump request and returns the replies.

    Args:
      diag_req: an InetDiagReqV2 describing the sockets to dump.
      bytecode: extra request attributes appended to the request, e.g. an
        INET_DIAG_REQ_BYTECODE attribute as built by DumpAllInetSockets.

    Returns:
      The replies from _Dump, iterable as (InetDiagMsg, attributes) pairs.
    """
    return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)

  def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
                         states=ALL_NON_TIME_WAIT):
    """Dumps IPv4 and IPv6 sockets matching the specified parameters.

    One dump is sent for AF_INET and one for AF_INET6, and the results are
    concatenated.

    Args:
      protocol: the IP protocol to dump, e.g. IPPROTO_TCP.
      bytecode: raw filter bytecode (e.g. from PackBytecode), or "" for no
        filter. It is sent as an INET_DIAG_REQ_BYTECODE attribute.
      sock_id: the InetDiagSockId to put in the request, or None for an
        all-zeros one.
      ext: bitmask of extensions to request; bit (X - 1) requests INET_DIAG
        extension X.
      states: bitmask of socket states; bit N selects state N.

    Returns:
      A list of the entries returned by Dump for both families.
    """
    # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
    # results in ENOENT.
    if sock_id is None:
      sock_id = self._EmptyInetDiagSockId()

    if bytecode:
      bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)

    sockets = []
    for family in (AF_INET, AF_INET6):
      diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
      sockets += self.Dump(diag_req, bytecode)

    return sockets

  @staticmethod
  def GetRawAddress(family, addr):
    """Converts a (possibly padded) binary address to a printable string.

    Args:
      family: AF_INET or AF_INET6.
      addr: binary address such as InetDiagSockId.src or .dst. Only the first
        4 (AF_INET) or 16 (AF_INET6) bytes are used.

    Returns:
      The address in presentation format, as produced by inet_ntop.

    Raises:
      KeyError: if family is neither AF_INET nor AF_INET6.
    """
    addrlen = {AF_INET: 4, AF_INET6: 16}[family]
    return inet_ntop(family, addr[:addrlen])

  @staticmethod
  def GetSourceAddress(diag_msg):
    """Fetches the source address from an InetDiagMsg, as a string."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)

  @staticmethod
  def GetDestinationAddress(diag_msg):
    """Fetches the destination address from an InetDiagMsg, as a string."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)

  @staticmethod
  def RawAddress(addr):
    """Converts an IP address string to binary format.

    The family is inferred from the string: IPv6 if it contains ':',
    otherwise IPv4.
    """
    family = AF_INET6 if ":" in addr else AF_INET
    return inet_pton(family, addr)

  @staticmethod
  def PaddedAddress(addr):
    """Converts an IP address string to binary format for InetDiagSockId.

    IPv4 addresses are right-padded with zero bytes to 16 bytes.
    """
    return SockDiag.RawAddress(addr).ljust(16, "\x00")

  @staticmethod
  def _GetInterfaceIndex(s, ifname):
    """Returns the index of the network interface named ifname.

    Issues the SIOCGIFINDEX ioctl on s, which serves only as a file
    descriptor for the request.

    Args:
      s: an open socket.
      ifname: the interface name, without trailing NUL bytes.

    Raises:
      IOError (OSError on Python 3): if the ioctl fails, e.g. if there is no
        such interface.
    """
    siocgifindex = 0x8933  # See include/uapi/linux/sockios.h.
    # struct ifreq starts with a 16-byte name followed by an int index.
    ifreq = struct.pack("16si", ifname, 0)
    ifreq = fcntl.ioctl(s.fileno(), siocgifindex, ifreq)
    return struct.unpack("16si", ifreq)[1]

  @staticmethod
  def DiagReqFromSocket(s):
    """Creates an InetDiagReqV2 that matches the specified socket.

    The request matches any socket state. If s is not connected, the peer
    address and port are set to all zeros.

    Args:
      s: an inet socket.

    Returns:
      An InetDiagReqV2.
    """
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)

    iface = 0
    # Reading SO_BINDTODEVICE is only attempted on 3.8+ kernels; on older
    # kernels the socket is treated as unbound.
    if net_test.LINUX_VERSION >= (3, 8):
      ifname = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
                            net_test.IFNAMSIZ)
      ifname = ifname.split("\x00", 1)[0]
      if ifname:
        iface = SockDiag._GetInterfaceIndex(s, ifname)

    src, sport = s.getsockname()[:2]
    try:
      dst, dport = s.getpeername()[:2]
    except error as e:
      if e.errno != errno.ENOTCONN:
        raise
      # Not connected: use an all-zeros peer address and port.
      dport = 0
      dst = "::" if family == AF_INET6 else "0.0.0.0"

    sock_id = InetDiagSockId((sport, dport,
                              SockDiag.PaddedAddress(src),
                              SockDiag.PaddedAddress(dst),
                              iface, "\x00" * 8))
    # 0xffffffff: every state bit set.
    return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))

  def FindSockDiagFromReq(self, req):
    """Returns the first InetDiagMsg that a dump of req produces.

    Raises:
      ValueError: if the dump returns no sockets.
    """
    for diag_msg, _ in self.Dump(req):
      return diag_msg
    raise ValueError("Dump of %s returned no sockets" % req)

  def FindSockDiagFromFd(self, s):
    """Gets an InetDiagMsg from the kernel for the specified socket."""
    req = self.DiagReqFromSocket(s)
    return self.FindSockDiagFromReq(req)

  def GetSockDiag(self, req):
    """Gets an InetDiagMsg from the kernel for the specified request.

    Sends req as a non-dump SOCK_DIAG_BY_FAMILY request and returns the
    first element of the reply parsed by _GetMsg (presumably the
    InetDiagMsg header; confirm against netlink._GetMsg).
    """
    self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
    return self._GetMsg(InetDiagMsg)[0]

  @staticmethod
  def DiagReqFromDiagMsg(d, protocol):
    """Constructs a diag_req from a diag_msg the kernel has given us.

    The request's state mask selects only d's current state.
    """
    return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))

  def CloseSocket(self, req):
    """Sends a SOCK_DESTROY request for req, with NLM_F_ACK set."""
    self._SendNlRequest(SOCK_DESTROY, req.Pack(),
                        netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)

  def CloseSocketFromFd(self, s):
    """Destroys socket s by looking it up and sending SOCK_DESTROY for it."""
    diag_msg = self.FindSockDiagFromFd(s)
    protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
    req = self.DiagReqFromDiagMsg(diag_msg, protocol)
    return self.CloseSocket(req)
330
331
if __name__ == "__main__":
  # Ad-hoc manual check: dumps TCP sockets in every state, requesting the TOS
  # and TCLASS extensions, with dport=443 set in the request's sock_id.
  n = SockDiag()
  n.DEBUG = True
  bytecode = ""  # No filter.
  sock_id = n._EmptyInetDiagSockId()
  sock_id.dport = 443
  # Bit (X - 1) of ext requests extension X.
  ext = (1 << (INET_DIAG_TOS - 1)) | (1 << (INET_DIAG_TCLASS - 1))
  states = 0xffffffff
  diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                   sock_id=sock_id, ext=ext, states=states)
  print(diag_msgs)
343