/* Kernel module to match connection tracking byte counter.
 * GPL (C) 2002 Martin Devera (devik@cdi.cz).
 */
4#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5#include <linux/module.h>
6#include <linux/bitops.h>
7#include <linux/skbuff.h>
8#include <linux/math64.h>
9#include <linux/netfilter/x_tables.h>
10#include <linux/netfilter/xt_connbytes.h>
11#include <net/netfilter/nf_conntrack.h>
12#include <net/netfilter/nf_conntrack_acct.h>
13
14MODULE_LICENSE("GPL");
15MODULE_AUTHOR("Harald Welte <laforge@netfilter.org>");
16MODULE_DESCRIPTION("Xtables: Number of packets/bytes per connection matching");
17MODULE_ALIAS("ipt_connbytes");
18MODULE_ALIAS("ip6t_connbytes");
19
20static bool
21connbytes_mt(const struct sk_buff *skb, struct xt_action_param *par)
22{
23	const struct xt_connbytes_info *sinfo = par->matchinfo;
24	const struct nf_conn *ct;
25	enum ip_conntrack_info ctinfo;
26	u_int64_t what = 0;	/* initialize to make gcc happy */
27	u_int64_t bytes = 0;
28	u_int64_t pkts = 0;
29	const struct nf_conn_acct *acct;
30	const struct nf_conn_counter *counters;
31
32	ct = nf_ct_get(skb, &ctinfo);
33	if (!ct)
34		return false;
35
36	acct = nf_conn_acct_find(ct);
37	if (!acct)
38		return false;
39
40	counters = acct->counter;
41	switch (sinfo->what) {
42	case XT_CONNBYTES_PKTS:
43		switch (sinfo->direction) {
44		case XT_CONNBYTES_DIR_ORIGINAL:
45			what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets);
46			break;
47		case XT_CONNBYTES_DIR_REPLY:
48			what = atomic64_read(&counters[IP_CT_DIR_REPLY].packets);
49			break;
50		case XT_CONNBYTES_DIR_BOTH:
51			what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets);
52			what += atomic64_read(&counters[IP_CT_DIR_REPLY].packets);
53			break;
54		}
55		break;
56	case XT_CONNBYTES_BYTES:
57		switch (sinfo->direction) {
58		case XT_CONNBYTES_DIR_ORIGINAL:
59			what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes);
60			break;
61		case XT_CONNBYTES_DIR_REPLY:
62			what = atomic64_read(&counters[IP_CT_DIR_REPLY].bytes);
63			break;
64		case XT_CONNBYTES_DIR_BOTH:
65			what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes);
66			what += atomic64_read(&counters[IP_CT_DIR_REPLY].bytes);
67			break;
68		}
69		break;
70	case XT_CONNBYTES_AVGPKT:
71		switch (sinfo->direction) {
72		case XT_CONNBYTES_DIR_ORIGINAL:
73			bytes = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes);
74			pkts  = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets);
75			break;
76		case XT_CONNBYTES_DIR_REPLY:
77			bytes = atomic64_read(&counters[IP_CT_DIR_REPLY].bytes);
78			pkts  = atomic64_read(&counters[IP_CT_DIR_REPLY].packets);
79			break;
80		case XT_CONNBYTES_DIR_BOTH:
81			bytes = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes) +
82				atomic64_read(&counters[IP_CT_DIR_REPLY].bytes);
83			pkts  = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets) +
84				atomic64_read(&counters[IP_CT_DIR_REPLY].packets);
85			break;
86		}
87		if (pkts != 0)
88			what = div64_u64(bytes, pkts);
89		break;
90	}
91
92	if (sinfo->count.to >= sinfo->count.from)
93		return what <= sinfo->count.to && what >= sinfo->count.from;
94	else /* inverted */
95		return what < sinfo->count.to || what > sinfo->count.from;
96}
97
98static int connbytes_mt_check(const struct xt_mtchk_param *par)
99{
100	const struct xt_connbytes_info *sinfo = par->matchinfo;
101	int ret;
102
103	if (sinfo->what != XT_CONNBYTES_PKTS &&
104	    sinfo->what != XT_CONNBYTES_BYTES &&
105	    sinfo->what != XT_CONNBYTES_AVGPKT)
106		return -EINVAL;
107
108	if (sinfo->direction != XT_CONNBYTES_DIR_ORIGINAL &&
109	    sinfo->direction != XT_CONNBYTES_DIR_REPLY &&
110	    sinfo->direction != XT_CONNBYTES_DIR_BOTH)
111		return -EINVAL;
112
113	ret = nf_ct_l3proto_try_module_get(par->family);
114	if (ret < 0)
115		pr_info("cannot load conntrack support for proto=%u\n",
116			par->family);
117
118	/*
119	 * This filter cannot function correctly unless connection tracking
120	 * accounting is enabled, so complain in the hope that someone notices.
121	 */
122	if (!nf_ct_acct_enabled(par->net)) {
123		pr_warn("Forcing CT accounting to be enabled\n");
124		nf_ct_set_acct(par->net, true);
125	}
126
127	return ret;
128}
129
130static void connbytes_mt_destroy(const struct xt_mtdtor_param *par)
131{
132	nf_ct_l3proto_module_put(par->family);
133}
134
135static struct xt_match connbytes_mt_reg __read_mostly = {
136	.name       = "connbytes",
137	.revision   = 0,
138	.family     = NFPROTO_UNSPEC,
139	.checkentry = connbytes_mt_check,
140	.match      = connbytes_mt,
141	.destroy    = connbytes_mt_destroy,
142	.matchsize  = sizeof(struct xt_connbytes_info),
143	.me         = THIS_MODULE,
144};
145
146static int __init connbytes_mt_init(void)
147{
148	return xt_register_match(&connbytes_mt_reg);
149}
150
151static void __exit connbytes_mt_exit(void)
152{
153	xt_unregister_match(&connbytes_mt_reg);
154}
155
/* Hook the init/exit functions into module load and unload. */
module_init(connbytes_mt_init);
module_exit(connbytes_mt_exit);
158