ip_set_core.c revision 9076aea76538556224e7d73ab718f8841330818a
1/* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
2 *                         Patrick Schaaf <bof@bof.de>
3 * Copyright (C) 2003-2011 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License version 2 as
7 * published by the Free Software Foundation.
8 */
9
10/* Kernel module for IP set management */
11
12#include <linux/init.h>
13#include <linux/module.h>
14#include <linux/moduleparam.h>
15#include <linux/ip.h>
16#include <linux/skbuff.h>
17#include <linux/spinlock.h>
18#include <linux/netlink.h>
19#include <linux/rculist.h>
20#include <net/netlink.h>
21
22#include <linux/netfilter.h>
23#include <linux/netfilter/x_tables.h>
24#include <linux/netfilter/nfnetlink.h>
25#include <linux/netfilter/ipset/ip_set.h>
26
static LIST_HEAD(ip_set_type_list);		/* all registered set types */
static DEFINE_MUTEX(ip_set_type_mutex);		/* protects ip_set_type_list */
static DEFINE_RWLOCK(ip_set_ref_lock);		/* protects the set refs */

static struct ip_set * __rcu *ip_set_list;	/* all individual sets */
static ip_set_id_t ip_set_max = CONFIG_IP_SET_MAX; /* max number of sets */

/* Number of slots ip_set_list grows by when ip_set_create() finds it full */
#define IP_SET_INC	64
/* Name comparison bounded by IPSET_MAXNAMELEN bytes */
#define STREQ(a, b)	(strncmp(a, b, IPSET_MAXNAMELEN) == 0)

/* Backing variable of the "max_sets" module parameter */
static unsigned int max_sets;

module_param(max_sets, int, 0600);
MODULE_PARM_DESC(max_sets, "maximal number of sets");
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
MODULE_DESCRIPTION("core IP set support");
MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);

/*
 * When the nfnl mutex is held:
 * the accessors below pass a constant "1" as the protection condition,
 * so it is entirely up to the caller to really hold the mutex.
 */
#define nfnl_dereference(p)		\
	rcu_dereference_protected(p, 1)
/* Slot "id" of ip_set_list; usable both as a value and as an assignment target */
#define nfnl_set(id)			\
	nfnl_dereference(ip_set_list)[id]
51
52/*
53 * The set types are implemented in modules and registered set types
54 * can be found in ip_set_type_list. Adding/deleting types is
55 * serialized by ip_set_type_mutex.
56 */
57
58static inline void
59ip_set_type_lock(void)
60{
61	mutex_lock(&ip_set_type_mutex);
62}
63
64static inline void
65ip_set_type_unlock(void)
66{
67	mutex_unlock(&ip_set_type_mutex);
68}
69
70/* Register and deregister settype */
71
72static struct ip_set_type *
73find_set_type(const char *name, u8 family, u8 revision)
74{
75	struct ip_set_type *type;
76
77	list_for_each_entry_rcu(type, &ip_set_type_list, list)
78		if (STREQ(type->name, name) &&
79		    (type->family == family ||
80		     type->family == NFPROTO_UNSPEC) &&
81		    revision >= type->revision_min &&
82		    revision <= type->revision_max)
83			return type;
84	return NULL;
85}
86
87/* Unlock, try to load a set type module and lock again */
88static bool
89load_settype(const char *name)
90{
91	nfnl_unlock();
92	pr_debug("try to load ip_set_%s\n", name);
93	if (request_module("ip_set_%s", name) < 0) {
94		pr_warning("Can't find ip_set type %s\n", name);
95		nfnl_lock();
96		return false;
97	}
98	nfnl_lock();
99	return true;
100}
101
102/* Find a set type and reference it */
103#define find_set_type_get(name, family, revision, found)	\
104	__find_set_type_get(name, family, revision, found, false)
105
106static int
107__find_set_type_get(const char *name, u8 family, u8 revision,
108		    struct ip_set_type **found, bool retry)
109{
110	struct ip_set_type *type;
111	int err;
112
113	if (retry && !load_settype(name))
114		return -IPSET_ERR_FIND_TYPE;
115
116	rcu_read_lock();
117	*found = find_set_type(name, family, revision);
118	if (*found) {
119		err = !try_module_get((*found)->me) ? -EFAULT : 0;
120		goto unlock;
121	}
122	/* Make sure the type is already loaded
123	 * but we don't support the revision */
124	list_for_each_entry_rcu(type, &ip_set_type_list, list)
125		if (STREQ(type->name, name)) {
126			err = -IPSET_ERR_FIND_TYPE;
127			goto unlock;
128		}
129	rcu_read_unlock();
130
131	return retry ? -IPSET_ERR_FIND_TYPE :
132		__find_set_type_get(name, family, revision, found, true);
133
134unlock:
135	rcu_read_unlock();
136	return err;
137}
138
139/* Find a given set type by name and family.
140 * If we succeeded, the supported minimal and maximum revisions are
141 * filled out.
142 */
143#define find_set_type_minmax(name, family, min, max) \
144	__find_set_type_minmax(name, family, min, max, false)
145
146static int
147__find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max,
148		       bool retry)
149{
150	struct ip_set_type *type;
151	bool found = false;
152
153	if (retry && !load_settype(name))
154		return -IPSET_ERR_FIND_TYPE;
155
156	*min = 255; *max = 0;
157	rcu_read_lock();
158	list_for_each_entry_rcu(type, &ip_set_type_list, list)
159		if (STREQ(type->name, name) &&
160		    (type->family == family ||
161		     type->family == NFPROTO_UNSPEC)) {
162			found = true;
163			if (type->revision_min < *min)
164				*min = type->revision_min;
165			if (type->revision_max > *max)
166				*max = type->revision_max;
167		}
168	rcu_read_unlock();
169	if (found)
170		return 0;
171
172	return retry ? -IPSET_ERR_FIND_TYPE :
173		__find_set_type_minmax(name, family, min, max, true);
174}
175
/* Printable name of a protocol family; anything but IPv4/IPv6 is "any" */
#define family_name(f)	((f) == NFPROTO_IPV6 ? "inet6" : \
			 (f) == NFPROTO_IPV4 ? "inet" : "any")
178
179/* Register a set type structure. The type is identified by
180 * the unique triple of name, family and revision.
181 */
182int
183ip_set_type_register(struct ip_set_type *type)
184{
185	int ret = 0;
186
187	if (type->protocol != IPSET_PROTOCOL) {
188		pr_warning("ip_set type %s, family %s, revision %u:%u uses "
189			   "wrong protocol version %u (want %u)\n",
190			   type->name, family_name(type->family),
191			   type->revision_min, type->revision_max,
192			   type->protocol, IPSET_PROTOCOL);
193		return -EINVAL;
194	}
195
196	ip_set_type_lock();
197	if (find_set_type(type->name, type->family, type->revision_min)) {
198		/* Duplicate! */
199		pr_warning("ip_set type %s, family %s with revision min %u "
200			   "already registered!\n", type->name,
201			   family_name(type->family), type->revision_min);
202		ret = -EINVAL;
203		goto unlock;
204	}
205	list_add_rcu(&type->list, &ip_set_type_list);
206	pr_debug("type %s, family %s, revision %u:%u registered.\n",
207		 type->name, family_name(type->family),
208		 type->revision_min, type->revision_max);
209unlock:
210	ip_set_type_unlock();
211	return ret;
212}
213EXPORT_SYMBOL_GPL(ip_set_type_register);
214
215/* Unregister a set type. There's a small race with ip_set_create */
216void
217ip_set_type_unregister(struct ip_set_type *type)
218{
219	ip_set_type_lock();
220	if (!find_set_type(type->name, type->family, type->revision_min)) {
221		pr_warning("ip_set type %s, family %s with revision min %u "
222			   "not registered\n", type->name,
223			   family_name(type->family), type->revision_min);
224		goto unlock;
225	}
226	list_del_rcu(&type->list);
227	pr_debug("type %s, family %s with revision min %u unregistered.\n",
228		 type->name, family_name(type->family), type->revision_min);
229unlock:
230	ip_set_type_unlock();
231
232	synchronize_rcu();
233}
234EXPORT_SYMBOL_GPL(ip_set_type_unregister);
235
236/* Utility functions */
237void *
238ip_set_alloc(size_t size)
239{
240	void *members = NULL;
241
242	if (size < KMALLOC_MAX_SIZE)
243		members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
244
245	if (members) {
246		pr_debug("%p: allocated with kmalloc\n", members);
247		return members;
248	}
249
250	members = vzalloc(size);
251	if (!members)
252		return NULL;
253	pr_debug("%p: allocated with vmalloc\n", members);
254
255	return members;
256}
257EXPORT_SYMBOL_GPL(ip_set_alloc);
258
259void
260ip_set_free(void *members)
261{
262	pr_debug("%p: free with %s\n", members,
263		 is_vmalloc_addr(members) ? "vfree" : "kfree");
264	if (is_vmalloc_addr(members))
265		vfree(members);
266	else
267		kfree(members);
268}
269EXPORT_SYMBOL_GPL(ip_set_free);
270
271static inline bool
272flag_nested(const struct nlattr *nla)
273{
274	return nla->nla_type & NLA_F_NESTED;
275}
276
277static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
278	[IPSET_ATTR_IPADDR_IPV4]	= { .type = NLA_U32 },
279	[IPSET_ATTR_IPADDR_IPV6]	= { .type = NLA_BINARY,
280					    .len = sizeof(struct in6_addr) },
281};
282
283int
284ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
285{
286	struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
287
288	if (unlikely(!flag_nested(nla)))
289		return -IPSET_ERR_PROTOCOL;
290	if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
291		return -IPSET_ERR_PROTOCOL;
292	if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
293		return -IPSET_ERR_PROTOCOL;
294
295	*ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
296	return 0;
297}
298EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
299
300int
301ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
302{
303	struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
304
305	if (unlikely(!flag_nested(nla)))
306		return -IPSET_ERR_PROTOCOL;
307
308	if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
309		return -IPSET_ERR_PROTOCOL;
310	if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
311		return -IPSET_ERR_PROTOCOL;
312
313	memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
314		sizeof(struct in6_addr));
315	return 0;
316}
317EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
318
319/*
320 * Creating/destroying/renaming/swapping affect the existence and
321 * the properties of a set. All of these can be executed from userspace
322 * only and serialized by the nfnl mutex indirectly from nfnetlink.
323 *
324 * Sets are identified by their index in ip_set_list and the index
325 * is used by the external references (set/SET netfilter modules).
326 *
327 * The set behind an index may change by swapping only, from userspace.
328 */
329
330static inline void
331__ip_set_get(struct ip_set *set)
332{
333	write_lock_bh(&ip_set_ref_lock);
334	set->ref++;
335	write_unlock_bh(&ip_set_ref_lock);
336}
337
338static inline void
339__ip_set_put(struct ip_set *set)
340{
341	write_lock_bh(&ip_set_ref_lock);
342	BUG_ON(set->ref == 0);
343	set->ref--;
344	write_unlock_bh(&ip_set_ref_lock);
345}
346
347/*
348 * Add, del and test set entries from kernel.
349 *
350 * The set behind the index must exist and must be referenced
351 * so it can't be destroyed (or changed) under our foot.
352 */
353
354static inline struct ip_set *
355ip_set_rcu_get(ip_set_id_t index)
356{
357	struct ip_set *set;
358
359	rcu_read_lock();
360	/* ip_set_list itself needs to be protected */
361	set = rcu_dereference(ip_set_list)[index];
362	rcu_read_unlock();
363
364	return set;
365}
366
367int
368ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
369	    const struct xt_action_param *par,
370	    const struct ip_set_adt_opt *opt)
371{
372	struct ip_set *set = ip_set_rcu_get(index);
373	int ret = 0;
374
375	BUG_ON(set == NULL);
376	pr_debug("set %s, index %u\n", set->name, index);
377
378	if (opt->dim < set->type->dimension ||
379	    !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
380		return 0;
381
382	read_lock_bh(&set->lock);
383	ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt);
384	read_unlock_bh(&set->lock);
385
386	if (ret == -EAGAIN) {
387		/* Type requests element to be completed */
388		pr_debug("element must be competed, ADD is triggered\n");
389		write_lock_bh(&set->lock);
390		set->variant->kadt(set, skb, par, IPSET_ADD, opt);
391		write_unlock_bh(&set->lock);
392		ret = 1;
393	} else {
394		/* --return-nomatch: invert matched element */
395		if ((opt->flags & IPSET_RETURN_NOMATCH) &&
396		    (set->type->features & IPSET_TYPE_NOMATCH) &&
397		    (ret > 0 || ret == -ENOTEMPTY))
398			ret = -ret;
399	}
400
401	/* Convert error codes to nomatch */
402	return (ret < 0 ? 0 : ret);
403}
404EXPORT_SYMBOL_GPL(ip_set_test);
405
406int
407ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
408	   const struct xt_action_param *par,
409	   const struct ip_set_adt_opt *opt)
410{
411	struct ip_set *set = ip_set_rcu_get(index);
412	int ret;
413
414	BUG_ON(set == NULL);
415	pr_debug("set %s, index %u\n", set->name, index);
416
417	if (opt->dim < set->type->dimension ||
418	    !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
419		return 0;
420
421	write_lock_bh(&set->lock);
422	ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt);
423	write_unlock_bh(&set->lock);
424
425	return ret;
426}
427EXPORT_SYMBOL_GPL(ip_set_add);
428
429int
430ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
431	   const struct xt_action_param *par,
432	   const struct ip_set_adt_opt *opt)
433{
434	struct ip_set *set = ip_set_rcu_get(index);
435	int ret = 0;
436
437	BUG_ON(set == NULL);
438	pr_debug("set %s, index %u\n", set->name, index);
439
440	if (opt->dim < set->type->dimension ||
441	    !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
442		return 0;
443
444	write_lock_bh(&set->lock);
445	ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt);
446	write_unlock_bh(&set->lock);
447
448	return ret;
449}
450EXPORT_SYMBOL_GPL(ip_set_del);
451
452/*
453 * Find set by name, reference it once. The reference makes sure the
454 * thing pointed to, does not go away under our feet.
455 *
456 */
457ip_set_id_t
458ip_set_get_byname(const char *name, struct ip_set **set)
459{
460	ip_set_id_t i, index = IPSET_INVALID_ID;
461	struct ip_set *s;
462
463	rcu_read_lock();
464	for (i = 0; i < ip_set_max; i++) {
465		s = rcu_dereference(ip_set_list)[i];
466		if (s != NULL && STREQ(s->name, name)) {
467			__ip_set_get(s);
468			index = i;
469			*set = s;
470			break;
471		}
472	}
473	rcu_read_unlock();
474
475	return index;
476}
477EXPORT_SYMBOL_GPL(ip_set_get_byname);
478
/*
 * If the given set index points to a valid set, decrement
 * its reference count by 1. The caller shall not assume the index
 * to be valid after calling this function.
 *
 */
485void
486ip_set_put_byindex(ip_set_id_t index)
487{
488	struct ip_set *set;
489
490	rcu_read_lock();
491	set = rcu_dereference(ip_set_list)[index];
492	if (set != NULL)
493		__ip_set_put(set);
494	rcu_read_unlock();
495}
496EXPORT_SYMBOL_GPL(ip_set_put_byindex);
497
498/*
499 * Get the name of a set behind a set index.
500 * We assume the set is referenced, so it does exist and
501 * can't be destroyed. The set cannot be renamed due to
502 * the referencing either.
503 *
504 */
505const char *
506ip_set_name_byindex(ip_set_id_t index)
507{
508	const struct ip_set *set = ip_set_rcu_get(index);
509
510	BUG_ON(set == NULL);
511	BUG_ON(set->ref == 0);
512
513	/* Referenced, so it's safe */
514	return set->name;
515}
516EXPORT_SYMBOL_GPL(ip_set_name_byindex);
517
518/*
519 * Routines to call by external subsystems, which do not
520 * call nfnl_lock for us.
521 */
522
523/*
524 * Find set by name, reference it once. The reference makes sure the
525 * thing pointed to, does not go away under our feet.
526 *
527 * The nfnl mutex is used in the function.
528 */
529ip_set_id_t
530ip_set_nfnl_get(const char *name)
531{
532	ip_set_id_t i, index = IPSET_INVALID_ID;
533	struct ip_set *s;
534
535	nfnl_lock();
536	for (i = 0; i < ip_set_max; i++) {
537		s = nfnl_set(i);
538		if (s != NULL && STREQ(s->name, name)) {
539			__ip_set_get(s);
540			index = i;
541			break;
542		}
543	}
544	nfnl_unlock();
545
546	return index;
547}
548EXPORT_SYMBOL_GPL(ip_set_nfnl_get);
549
550/*
551 * Find set by index, reference it once. The reference makes sure the
552 * thing pointed to, does not go away under our feet.
553 *
554 * The nfnl mutex is used in the function.
555 */
556ip_set_id_t
557ip_set_nfnl_get_byindex(ip_set_id_t index)
558{
559	struct ip_set *set;
560
561	if (index > ip_set_max)
562		return IPSET_INVALID_ID;
563
564	nfnl_lock();
565	set = nfnl_set(index);
566	if (set)
567		__ip_set_get(set);
568	else
569		index = IPSET_INVALID_ID;
570	nfnl_unlock();
571
572	return index;
573}
574EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
575
/*
 * If the given set index points to a valid set, decrement
 * its reference count by 1. The caller shall not assume the index
 * to be valid after calling this function.
 *
 * The nfnl mutex is used in the function.
 */
583void
584ip_set_nfnl_put(ip_set_id_t index)
585{
586	struct ip_set *set;
587	nfnl_lock();
588	set = nfnl_set(index);
589	if (set != NULL)
590		__ip_set_put(set);
591	nfnl_unlock();
592}
593EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
594
595/*
596 * Communication protocol with userspace over netlink.
597 *
598 * The commands are serialized by the nfnl mutex.
599 */
600
601static inline bool
602protocol_failed(const struct nlattr * const tb[])
603{
604	return !tb[IPSET_ATTR_PROTOCOL] ||
605	       nla_get_u8(tb[IPSET_ATTR_PROTOCOL]) != IPSET_PROTOCOL;
606}
607
608static inline u32
609flag_exist(const struct nlmsghdr *nlh)
610{
611	return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
612}
613
614static struct nlmsghdr *
615start_msg(struct sk_buff *skb, u32 portid, u32 seq, unsigned int flags,
616	  enum ipset_cmd cmd)
617{
618	struct nlmsghdr *nlh;
619	struct nfgenmsg *nfmsg;
620
621	nlh = nlmsg_put(skb, portid, seq, cmd | (NFNL_SUBSYS_IPSET << 8),
622			sizeof(*nfmsg), flags);
623	if (nlh == NULL)
624		return NULL;
625
626	nfmsg = nlmsg_data(nlh);
627	nfmsg->nfgen_family = NFPROTO_IPV4;
628	nfmsg->version = NFNETLINK_V0;
629	nfmsg->res_id = 0;
630
631	return nlh;
632}
633
634/* Create a set */
635
636static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
637	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
638	[IPSET_ATTR_SETNAME]	= { .type = NLA_NUL_STRING,
639				    .len = IPSET_MAXNAMELEN - 1 },
640	[IPSET_ATTR_TYPENAME]	= { .type = NLA_NUL_STRING,
641				    .len = IPSET_MAXNAMELEN - 1},
642	[IPSET_ATTR_REVISION]	= { .type = NLA_U8 },
643	[IPSET_ATTR_FAMILY]	= { .type = NLA_U8 },
644	[IPSET_ATTR_DATA]	= { .type = NLA_NESTED },
645};
646
647static struct ip_set *
648find_set_and_id(const char *name, ip_set_id_t *id)
649{
650	struct ip_set *set = NULL;
651	ip_set_id_t i;
652
653	*id = IPSET_INVALID_ID;
654	for (i = 0; i < ip_set_max; i++) {
655		set = nfnl_set(i);
656		if (set != NULL && STREQ(set->name, name)) {
657			*id = i;
658			break;
659		}
660	}
661	return (*id == IPSET_INVALID_ID ? NULL : set);
662}
663
664static inline struct ip_set *
665find_set(const char *name)
666{
667	ip_set_id_t id;
668
669	return find_set_and_id(name, &id);
670}
671
672static int
673find_free_id(const char *name, ip_set_id_t *index, struct ip_set **set)
674{
675	struct ip_set *s;
676	ip_set_id_t i;
677
678	*index = IPSET_INVALID_ID;
679	for (i = 0;  i < ip_set_max; i++) {
680		s = nfnl_set(i);
681		if (s == NULL) {
682			if (*index == IPSET_INVALID_ID)
683				*index = i;
684		} else if (STREQ(name, s->name)) {
685			/* Name clash */
686			*set = s;
687			return -EEXIST;
688		}
689	}
690	if (*index == IPSET_INVALID_ID)
691		/* No free slot remained */
692		return -IPSET_ERR_MAX_SETS;
693	return 0;
694}
695
696static int
697ip_set_none(struct sock *ctnl, struct sk_buff *skb,
698	    const struct nlmsghdr *nlh,
699	    const struct nlattr * const attr[])
700{
701	return -EOPNOTSUPP;
702}
703
/*
 * Handle the create command.
 *
 * The generic part of the set is allocated first, then the set type is
 * looked up and referenced, the type specific part is created and at
 * last the set is stored in a free slot of ip_set_list, growing the
 * list by IP_SET_INC slots if it is full.
 *
 * Returns 0 on success and a negative error code otherwise. If a set
 * with the same name exists already, the new set is always thrown
 * away; the result is 0 only if IPSET_FLAG_EXIST is set and the two
 * sets are the same, -EEXIST otherwise.
 */
static int
ip_set_create(struct sock *ctnl, struct sk_buff *skb,
	      const struct nlmsghdr *nlh,
	      const struct nlattr * const attr[])
{
	struct ip_set *set, *clash = NULL;
	ip_set_id_t index = IPSET_INVALID_ID;
	struct nlattr *tb[IPSET_ATTR_CREATE_MAX+1] = {};
	const char *name, *typename;
	u8 family, revision;
	u32 flags = flag_exist(nlh);
	int ret = 0;

	/* All attributes but the type specific data are mandatory */
	if (unlikely(protocol_failed(attr) ||
		     attr[IPSET_ATTR_SETNAME] == NULL ||
		     attr[IPSET_ATTR_TYPENAME] == NULL ||
		     attr[IPSET_ATTR_REVISION] == NULL ||
		     attr[IPSET_ATTR_FAMILY] == NULL ||
		     (attr[IPSET_ATTR_DATA] != NULL &&
		      !flag_nested(attr[IPSET_ATTR_DATA]))))
		return -IPSET_ERR_PROTOCOL;

	name = nla_data(attr[IPSET_ATTR_SETNAME]);
	typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
	family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
	revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
	pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
		 name, typename, family_name(family), revision);

	/*
	 * First, and without any locks, allocate and initialize
	 * a normal base set structure.
	 */
	set = kzalloc(sizeof(struct ip_set), GFP_KERNEL);
	if (!set)
		return -ENOMEM;
	rwlock_init(&set->lock);
	strlcpy(set->name, name, IPSET_MAXNAMELEN);
	set->family = family;
	set->revision = revision;

	/*
	 * Next, check that we know the type, and take
	 * a reference on the type, to make sure it stays available
	 * while constructing our new set.
	 *
	 * After referencing the type, we try to create the type
	 * specific part of the set without holding any locks.
	 */
	ret = find_set_type_get(typename, family, revision, &(set->type));
	if (ret)
		goto out;

	/*
	 * Without holding any locks, create private part.
	 * If there is no data attribute, tb[] stays all NULL.
	 */
	if (attr[IPSET_ATTR_DATA] &&
	    nla_parse_nested(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA],
			     set->type->create_policy)) {
		ret = -IPSET_ERR_PROTOCOL;
		goto put_out;
	}

	ret = set->type->create(set, tb, flags);
	if (ret != 0)
		goto put_out;

	/* BTW, ret==0 here. */

	/*
	 * Here, we have a valid, constructed set and we are protected
	 * by the nfnl mutex. Find the first free index in ip_set_list
	 * and check clashing.
	 */
	ret = find_free_id(set->name, &index, &clash);
	if (ret == -EEXIST) {
		/* If this is the same set and requested, ignore error */
		if ((flags & IPSET_FLAG_EXIST) &&
		    STREQ(set->type->name, clash->type->name) &&
		    set->type->family == clash->type->family &&
		    set->type->revision_min == clash->type->revision_min &&
		    set->type->revision_max == clash->type->revision_max &&
		    set->variant->same_set(set, clash))
			ret = 0;
		/* In both cases the freshly built set is released */
		goto cleanup;
	} else if (ret == -IPSET_ERR_MAX_SETS) {
		struct ip_set **list, **tmp;
		ip_set_id_t i = ip_set_max + IP_SET_INC;

		/* ret is still -IPSET_ERR_MAX_SETS on both bail-outs below */
		if (i < ip_set_max || i == IPSET_INVALID_ID)
			/* Wraparound */
			goto cleanup;

		list = kzalloc(sizeof(struct ip_set *) * i, GFP_KERNEL);
		if (!list)
			goto cleanup;
		/* nfnl mutex is held, both lists are valid */
		tmp = nfnl_dereference(ip_set_list);
		memcpy(list, tmp, sizeof(struct ip_set *) * ip_set_max);
		rcu_assign_pointer(ip_set_list, list);
		/* Make sure all current packets have passed through */
		synchronize_net();
		/* Use new list: the first added slot is the free one */
		index = ip_set_max;
		ip_set_max = i;
		kfree(tmp);
		ret = 0;
	} else if (ret)
		goto cleanup;

	/*
	 * Finally! Add our shiny new set to the list, and be done.
	 */
	pr_debug("create: '%s' created with index %u!\n", set->name, index);
	nfnl_set(index) = set;

	return ret;

	/* Error unwinding, in reverse order of the construction */
cleanup:
	set->variant->destroy(set);
put_out:
	module_put(set->type->me);
out:
	kfree(set);
	return ret;
}
830
831/* Destroy sets */
832
833static const struct nla_policy
834ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
835	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
836	[IPSET_ATTR_SETNAME]	= { .type = NLA_NUL_STRING,
837				    .len = IPSET_MAXNAMELEN - 1 },
838};
839
840static void
841ip_set_destroy_set(ip_set_id_t index)
842{
843	struct ip_set *set = nfnl_set(index);
844
845	pr_debug("set: %s\n",  set->name);
846	nfnl_set(index) = NULL;
847
848	/* Must call it without holding any lock */
849	set->variant->destroy(set);
850	module_put(set->type->me);
851	kfree(set);
852}
853
854static int
855ip_set_destroy(struct sock *ctnl, struct sk_buff *skb,
856	       const struct nlmsghdr *nlh,
857	       const struct nlattr * const attr[])
858{
859	struct ip_set *s;
860	ip_set_id_t i;
861	int ret = 0;
862
863	if (unlikely(protocol_failed(attr)))
864		return -IPSET_ERR_PROTOCOL;
865
866	/* Commands are serialized and references are
867	 * protected by the ip_set_ref_lock.
868	 * External systems (i.e. xt_set) must call
869	 * ip_set_put|get_nfnl_* functions, that way we
870	 * can safely check references here.
871	 *
872	 * list:set timer can only decrement the reference
873	 * counter, so if it's already zero, we can proceed
874	 * without holding the lock.
875	 */
876	read_lock_bh(&ip_set_ref_lock);
877	if (!attr[IPSET_ATTR_SETNAME]) {
878		for (i = 0; i < ip_set_max; i++) {
879			s = nfnl_set(i);
880			if (s != NULL && s->ref) {
881				ret = -IPSET_ERR_BUSY;
882				goto out;
883			}
884		}
885		read_unlock_bh(&ip_set_ref_lock);
886		for (i = 0; i < ip_set_max; i++) {
887			s = nfnl_set(i);
888			if (s != NULL)
889				ip_set_destroy_set(i);
890		}
891	} else {
892		s = find_set_and_id(nla_data(attr[IPSET_ATTR_SETNAME]), &i);
893		if (s == NULL) {
894			ret = -ENOENT;
895			goto out;
896		} else if (s->ref) {
897			ret = -IPSET_ERR_BUSY;
898			goto out;
899		}
900		read_unlock_bh(&ip_set_ref_lock);
901
902		ip_set_destroy_set(i);
903	}
904	return 0;
905out:
906	read_unlock_bh(&ip_set_ref_lock);
907	return ret;
908}
909
910/* Flush sets */
911
912static void
913ip_set_flush_set(struct ip_set *set)
914{
915	pr_debug("set: %s\n",  set->name);
916
917	write_lock_bh(&set->lock);
918	set->variant->flush(set);
919	write_unlock_bh(&set->lock);
920}
921
922static int
923ip_set_flush(struct sock *ctnl, struct sk_buff *skb,
924	     const struct nlmsghdr *nlh,
925	     const struct nlattr * const attr[])
926{
927	struct ip_set *s;
928	ip_set_id_t i;
929
930	if (unlikely(protocol_failed(attr)))
931		return -IPSET_ERR_PROTOCOL;
932
933	if (!attr[IPSET_ATTR_SETNAME]) {
934		for (i = 0; i < ip_set_max; i++) {
935			s = nfnl_set(i);
936			if (s != NULL)
937				ip_set_flush_set(s);
938		}
939	} else {
940		s = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
941		if (s == NULL)
942			return -ENOENT;
943
944		ip_set_flush_set(s);
945	}
946
947	return 0;
948}
949
950/* Rename a set */
951
952static const struct nla_policy
953ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
954	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
955	[IPSET_ATTR_SETNAME]	= { .type = NLA_NUL_STRING,
956				    .len = IPSET_MAXNAMELEN - 1 },
957	[IPSET_ATTR_SETNAME2]	= { .type = NLA_NUL_STRING,
958				    .len = IPSET_MAXNAMELEN - 1 },
959};
960
961static int
962ip_set_rename(struct sock *ctnl, struct sk_buff *skb,
963	      const struct nlmsghdr *nlh,
964	      const struct nlattr * const attr[])
965{
966	struct ip_set *set, *s;
967	const char *name2;
968	ip_set_id_t i;
969	int ret = 0;
970
971	if (unlikely(protocol_failed(attr) ||
972		     attr[IPSET_ATTR_SETNAME] == NULL ||
973		     attr[IPSET_ATTR_SETNAME2] == NULL))
974		return -IPSET_ERR_PROTOCOL;
975
976	set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
977	if (set == NULL)
978		return -ENOENT;
979
980	read_lock_bh(&ip_set_ref_lock);
981	if (set->ref != 0) {
982		ret = -IPSET_ERR_REFERENCED;
983		goto out;
984	}
985
986	name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
987	for (i = 0; i < ip_set_max; i++) {
988		s = nfnl_set(i);
989		if (s != NULL && STREQ(s->name, name2)) {
990			ret = -IPSET_ERR_EXIST_SETNAME2;
991			goto out;
992		}
993	}
994	strncpy(set->name, name2, IPSET_MAXNAMELEN);
995
996out:
997	read_unlock_bh(&ip_set_ref_lock);
998	return ret;
999}
1000
1001/* Swap two sets so that name/index points to the other.
1002 * References and set names are also swapped.
1003 *
1004 * The commands are serialized by the nfnl mutex and references are
1005 * protected by the ip_set_ref_lock. The kernel interfaces
1006 * do not hold the mutex but the pointer settings are atomic
1007 * so the ip_set_list always contains valid pointers to the sets.
1008 */
1009
1010static int
1011ip_set_swap(struct sock *ctnl, struct sk_buff *skb,
1012	    const struct nlmsghdr *nlh,
1013	    const struct nlattr * const attr[])
1014{
1015	struct ip_set *from, *to;
1016	ip_set_id_t from_id, to_id;
1017	char from_name[IPSET_MAXNAMELEN];
1018
1019	if (unlikely(protocol_failed(attr) ||
1020		     attr[IPSET_ATTR_SETNAME] == NULL ||
1021		     attr[IPSET_ATTR_SETNAME2] == NULL))
1022		return -IPSET_ERR_PROTOCOL;
1023
1024	from = find_set_and_id(nla_data(attr[IPSET_ATTR_SETNAME]), &from_id);
1025	if (from == NULL)
1026		return -ENOENT;
1027
1028	to = find_set_and_id(nla_data(attr[IPSET_ATTR_SETNAME2]), &to_id);
1029	if (to == NULL)
1030		return -IPSET_ERR_EXIST_SETNAME2;
1031
1032	/* Features must not change.
1033	 * Not an artificial restriction anymore, as we must prevent
1034	 * possible loops created by swapping in setlist type of sets. */
1035	if (!(from->type->features == to->type->features &&
1036	      from->type->family == to->type->family))
1037		return -IPSET_ERR_TYPE_MISMATCH;
1038
1039	strncpy(from_name, from->name, IPSET_MAXNAMELEN);
1040	strncpy(from->name, to->name, IPSET_MAXNAMELEN);
1041	strncpy(to->name, from_name, IPSET_MAXNAMELEN);
1042
1043	write_lock_bh(&ip_set_ref_lock);
1044	swap(from->ref, to->ref);
1045	nfnl_set(from_id) = to;
1046	nfnl_set(to_id) = from;
1047	write_unlock_bh(&ip_set_ref_lock);
1048
1049	return 0;
1050}
1051
1052/* List/save set data */
1053
/* Dump types, kept in the lower 16 bits of cb->args[0] */
#define DUMP_INIT	0
#define DUMP_ALL	1
#define DUMP_ONE	2
#define DUMP_LAST	3

/* cb->args[0] split into the dump type and the flags above it */
#define DUMP_TYPE(arg)		((u32)(arg) & 0xFFFFU)
#define DUMP_FLAGS(arg)		((u32)(arg) >> 16)
1061
1062static int
1063ip_set_dump_done(struct netlink_callback *cb)
1064{
1065	if (cb->args[2]) {
1066		pr_debug("release set %s\n", nfnl_set(cb->args[1])->name);
1067		ip_set_put_byindex((ip_set_id_t) cb->args[1]);
1068	}
1069	return 0;
1070}
1071
1072static inline void
1073dump_attrs(struct nlmsghdr *nlh)
1074{
1075	const struct nlattr *attr;
1076	int rem;
1077
1078	pr_debug("dump nlmsg\n");
1079	nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
1080		pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
1081	}
1082}
1083
1084static int
1085dump_init(struct netlink_callback *cb)
1086{
1087	struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
1088	int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
1089	struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
1090	struct nlattr *attr = (void *)nlh + min_len;
1091	u32 dump_type;
1092	ip_set_id_t index;
1093
1094	/* Second pass, so parser can't fail */
1095	nla_parse(cda, IPSET_ATTR_CMD_MAX,
1096		  attr, nlh->nlmsg_len - min_len, ip_set_setname_policy);
1097
1098	/* cb->args[0] : dump single set/all sets
1099	 *         [1] : set index
1100	 *         [..]: type specific
1101	 */
1102
1103	if (cda[IPSET_ATTR_SETNAME]) {
1104		struct ip_set *set;
1105
1106		set = find_set_and_id(nla_data(cda[IPSET_ATTR_SETNAME]),
1107				      &index);
1108		if (set == NULL)
1109			return -ENOENT;
1110
1111		dump_type = DUMP_ONE;
1112		cb->args[1] = index;
1113	} else
1114		dump_type = DUMP_ALL;
1115
1116	if (cda[IPSET_ATTR_FLAGS]) {
1117		u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]);
1118		dump_type |= (f << 16);
1119	}
1120	cb->args[0] = dump_type;
1121
1122	return 0;
1123}
1124
1125static int
1126ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
1127{
1128	ip_set_id_t index = IPSET_INVALID_ID, max;
1129	struct ip_set *set = NULL;
1130	struct nlmsghdr *nlh = NULL;
1131	unsigned int flags = NETLINK_CB(cb->skb).portid ? NLM_F_MULTI : 0;
1132	u32 dump_type, dump_flags;
1133	int ret = 0;
1134
1135	if (!cb->args[0]) {
1136		ret = dump_init(cb);
1137		if (ret < 0) {
1138			nlh = nlmsg_hdr(cb->skb);
1139			/* We have to create and send the error message
1140			 * manually :-( */
1141			if (nlh->nlmsg_flags & NLM_F_ACK)
1142				netlink_ack(cb->skb, nlh, ret);
1143			return ret;
1144		}
1145	}
1146
1147	if (cb->args[1] >= ip_set_max)
1148		goto out;
1149
1150	dump_type = DUMP_TYPE(cb->args[0]);
1151	dump_flags = DUMP_FLAGS(cb->args[0]);
1152	max = dump_type == DUMP_ONE ? cb->args[1] + 1 : ip_set_max;
1153dump_last:
1154	pr_debug("args[0]: %u %u args[1]: %ld\n",
1155		 dump_type, dump_flags, cb->args[1]);
1156	for (; cb->args[1] < max; cb->args[1]++) {
1157		index = (ip_set_id_t) cb->args[1];
1158		set = nfnl_set(index);
1159		if (set == NULL) {
1160			if (dump_type == DUMP_ONE) {
1161				ret = -ENOENT;
1162				goto out;
1163			}
1164			continue;
1165		}
1166		/* When dumping all sets, we must dump "sorted"
1167		 * so that lists (unions of sets) are dumped last.
1168		 */
1169		if (dump_type != DUMP_ONE &&
1170		    ((dump_type == DUMP_ALL) ==
1171		     !!(set->type->features & IPSET_DUMP_LAST)))
1172			continue;
1173		pr_debug("List set: %s\n", set->name);
1174		if (!cb->args[2]) {
1175			/* Start listing: make sure set won't be destroyed */
1176			pr_debug("reference set\n");
1177			__ip_set_get(set);
1178		}
1179		nlh = start_msg(skb, NETLINK_CB(cb->skb).portid,
1180				cb->nlh->nlmsg_seq, flags,
1181				IPSET_CMD_LIST);
1182		if (!nlh) {
1183			ret = -EMSGSIZE;
1184			goto release_refcount;
1185		}
1186		if (nla_put_u8(skb, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL) ||
1187		    nla_put_string(skb, IPSET_ATTR_SETNAME, set->name))
1188			goto nla_put_failure;
1189		if (dump_flags & IPSET_FLAG_LIST_SETNAME)
1190			goto next_set;
1191		switch (cb->args[2]) {
1192		case 0:
1193			/* Core header data */
1194			if (nla_put_string(skb, IPSET_ATTR_TYPENAME,
1195					   set->type->name) ||
1196			    nla_put_u8(skb, IPSET_ATTR_FAMILY,
1197				       set->family) ||
1198			    nla_put_u8(skb, IPSET_ATTR_REVISION,
1199				       set->revision))
1200				goto nla_put_failure;
1201			ret = set->variant->head(set, skb);
1202			if (ret < 0)
1203				goto release_refcount;
1204			if (dump_flags & IPSET_FLAG_LIST_HEADER)
1205				goto next_set;
1206			/* Fall through and add elements */
1207		default:
1208			read_lock_bh(&set->lock);
1209			ret = set->variant->list(set, skb, cb);
1210			read_unlock_bh(&set->lock);
1211			if (!cb->args[2])
1212				/* Set is done, proceed with next one */
1213				goto next_set;
1214			goto release_refcount;
1215		}
1216	}
1217	/* If we dump all sets, continue with dumping last ones */
1218	if (dump_type == DUMP_ALL) {
1219		dump_type = DUMP_LAST;
1220		cb->args[0] = dump_type | (dump_flags << 16);
1221		cb->args[1] = 0;
1222		goto dump_last;
1223	}
1224	goto out;
1225
1226nla_put_failure:
1227	ret = -EFAULT;
1228next_set:
1229	if (dump_type == DUMP_ONE)
1230		cb->args[1] = IPSET_INVALID_ID;
1231	else
1232		cb->args[1]++;
1233release_refcount:
1234	/* If there was an error or set is done, release set */
1235	if (ret || !cb->args[2]) {
1236		pr_debug("release set %s\n", nfnl_set(index)->name);
1237		ip_set_put_byindex(index);
1238		cb->args[2] = 0;
1239	}
1240out:
1241	if (nlh) {
1242		nlmsg_end(skb, nlh);
1243		pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1244		dump_attrs(nlh);
1245	}
1246
1247	return ret < 0 ? ret : skb->len;
1248}
1249
1250static int
1251ip_set_dump(struct sock *ctnl, struct sk_buff *skb,
1252	    const struct nlmsghdr *nlh,
1253	    const struct nlattr * const attr[])
1254{
1255	if (unlikely(protocol_failed(attr)))
1256		return -IPSET_ERR_PROTOCOL;
1257
1258	{
1259		struct netlink_dump_control c = {
1260			.dump = ip_set_dump_start,
1261			.done = ip_set_dump_done,
1262		};
1263		return netlink_dump_start(ctnl, skb, nlh, &c);
1264	}
1265}
1266
1267/* Add, del and test */
1268
1269static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1270	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
1271	[IPSET_ATTR_SETNAME]	= { .type = NLA_NUL_STRING,
1272				    .len = IPSET_MAXNAMELEN - 1 },
1273	[IPSET_ATTR_LINENO]	= { .type = NLA_U32 },
1274	[IPSET_ATTR_DATA]	= { .type = NLA_NESTED },
1275	[IPSET_ATTR_ADT]	= { .type = NLA_NESTED },
1276};
1277
1278static int
1279call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1280	struct nlattr *tb[], enum ipset_adt adt,
1281	u32 flags, bool use_lineno)
1282{
1283	int ret;
1284	u32 lineno = 0;
1285	bool eexist = flags & IPSET_FLAG_EXIST, retried = false;
1286
1287	do {
1288		write_lock_bh(&set->lock);
1289		ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried);
1290		write_unlock_bh(&set->lock);
1291		retried = true;
1292	} while (ret == -EAGAIN &&
1293		 set->variant->resize &&
1294		 (ret = set->variant->resize(set, retried)) == 0);
1295
1296	if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1297		return 0;
1298	if (lineno && use_lineno) {
1299		/* Error in restore/batch mode: send back lineno */
1300		struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1301		struct sk_buff *skb2;
1302		struct nlmsgerr *errmsg;
1303		size_t payload = sizeof(*errmsg) + nlmsg_len(nlh);
1304		int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
1305		struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
1306		struct nlattr *cmdattr;
1307		u32 *errline;
1308
1309		skb2 = nlmsg_new(payload, GFP_KERNEL);
1310		if (skb2 == NULL)
1311			return -ENOMEM;
1312		rep = __nlmsg_put(skb2, NETLINK_CB(skb).portid,
1313				  nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1314		errmsg = nlmsg_data(rep);
1315		errmsg->error = ret;
1316		memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1317		cmdattr = (void *)&errmsg->msg + min_len;
1318
1319		nla_parse(cda, IPSET_ATTR_CMD_MAX,
1320			  cmdattr, nlh->nlmsg_len - min_len,
1321			  ip_set_adt_policy);
1322
1323		errline = nla_data(cda[IPSET_ATTR_LINENO]);
1324
1325		*errline = lineno;
1326
1327		netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1328		/* Signal netlink not to send its ACK/errmsg.  */
1329		return -EINTR;
1330	}
1331
1332	return ret;
1333}
1334
1335static int
1336ip_set_uadd(struct sock *ctnl, struct sk_buff *skb,
1337	    const struct nlmsghdr *nlh,
1338	    const struct nlattr * const attr[])
1339{
1340	struct ip_set *set;
1341	struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1342	const struct nlattr *nla;
1343	u32 flags = flag_exist(nlh);
1344	bool use_lineno;
1345	int ret = 0;
1346
1347	if (unlikely(protocol_failed(attr) ||
1348		     attr[IPSET_ATTR_SETNAME] == NULL ||
1349		     !((attr[IPSET_ATTR_DATA] != NULL) ^
1350		       (attr[IPSET_ATTR_ADT] != NULL)) ||
1351		     (attr[IPSET_ATTR_DATA] != NULL &&
1352		      !flag_nested(attr[IPSET_ATTR_DATA])) ||
1353		     (attr[IPSET_ATTR_ADT] != NULL &&
1354		      (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1355		       attr[IPSET_ATTR_LINENO] == NULL))))
1356		return -IPSET_ERR_PROTOCOL;
1357
1358	set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1359	if (set == NULL)
1360		return -ENOENT;
1361
1362	use_lineno = !!attr[IPSET_ATTR_LINENO];
1363	if (attr[IPSET_ATTR_DATA]) {
1364		if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1365				     attr[IPSET_ATTR_DATA],
1366				     set->type->adt_policy))
1367			return -IPSET_ERR_PROTOCOL;
1368		ret = call_ad(ctnl, skb, set, tb, IPSET_ADD, flags,
1369			      use_lineno);
1370	} else {
1371		int nla_rem;
1372
1373		nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1374			memset(tb, 0, sizeof(tb));
1375			if (nla_type(nla) != IPSET_ATTR_DATA ||
1376			    !flag_nested(nla) ||
1377			    nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1378					     set->type->adt_policy))
1379				return -IPSET_ERR_PROTOCOL;
1380			ret = call_ad(ctnl, skb, set, tb, IPSET_ADD,
1381				      flags, use_lineno);
1382			if (ret < 0)
1383				return ret;
1384		}
1385	}
1386	return ret;
1387}
1388
1389static int
1390ip_set_udel(struct sock *ctnl, struct sk_buff *skb,
1391	    const struct nlmsghdr *nlh,
1392	    const struct nlattr * const attr[])
1393{
1394	struct ip_set *set;
1395	struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1396	const struct nlattr *nla;
1397	u32 flags = flag_exist(nlh);
1398	bool use_lineno;
1399	int ret = 0;
1400
1401	if (unlikely(protocol_failed(attr) ||
1402		     attr[IPSET_ATTR_SETNAME] == NULL ||
1403		     !((attr[IPSET_ATTR_DATA] != NULL) ^
1404		       (attr[IPSET_ATTR_ADT] != NULL)) ||
1405		     (attr[IPSET_ATTR_DATA] != NULL &&
1406		      !flag_nested(attr[IPSET_ATTR_DATA])) ||
1407		     (attr[IPSET_ATTR_ADT] != NULL &&
1408		      (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1409		       attr[IPSET_ATTR_LINENO] == NULL))))
1410		return -IPSET_ERR_PROTOCOL;
1411
1412	set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1413	if (set == NULL)
1414		return -ENOENT;
1415
1416	use_lineno = !!attr[IPSET_ATTR_LINENO];
1417	if (attr[IPSET_ATTR_DATA]) {
1418		if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1419				     attr[IPSET_ATTR_DATA],
1420				     set->type->adt_policy))
1421			return -IPSET_ERR_PROTOCOL;
1422		ret = call_ad(ctnl, skb, set, tb, IPSET_DEL, flags,
1423			      use_lineno);
1424	} else {
1425		int nla_rem;
1426
1427		nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1428			memset(tb, 0, sizeof(*tb));
1429			if (nla_type(nla) != IPSET_ATTR_DATA ||
1430			    !flag_nested(nla) ||
1431			    nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1432					     set->type->adt_policy))
1433				return -IPSET_ERR_PROTOCOL;
1434			ret = call_ad(ctnl, skb, set, tb, IPSET_DEL,
1435				      flags, use_lineno);
1436			if (ret < 0)
1437				return ret;
1438		}
1439	}
1440	return ret;
1441}
1442
1443static int
1444ip_set_utest(struct sock *ctnl, struct sk_buff *skb,
1445	     const struct nlmsghdr *nlh,
1446	     const struct nlattr * const attr[])
1447{
1448	struct ip_set *set;
1449	struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1450	int ret = 0;
1451
1452	if (unlikely(protocol_failed(attr) ||
1453		     attr[IPSET_ATTR_SETNAME] == NULL ||
1454		     attr[IPSET_ATTR_DATA] == NULL ||
1455		     !flag_nested(attr[IPSET_ATTR_DATA])))
1456		return -IPSET_ERR_PROTOCOL;
1457
1458	set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1459	if (set == NULL)
1460		return -ENOENT;
1461
1462	if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA],
1463			     set->type->adt_policy))
1464		return -IPSET_ERR_PROTOCOL;
1465
1466	read_lock_bh(&set->lock);
1467	ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0, 0);
1468	read_unlock_bh(&set->lock);
1469	/* Userspace can't trigger element to be re-added */
1470	if (ret == -EAGAIN)
1471		ret = 1;
1472
1473	return ret < 0 ? ret : ret > 0 ? 0 : -IPSET_ERR_EXIST;
1474}
1475
/* Get header data of a set */
1477
1478static int
1479ip_set_header(struct sock *ctnl, struct sk_buff *skb,
1480	      const struct nlmsghdr *nlh,
1481	      const struct nlattr * const attr[])
1482{
1483	const struct ip_set *set;
1484	struct sk_buff *skb2;
1485	struct nlmsghdr *nlh2;
1486	int ret = 0;
1487
1488	if (unlikely(protocol_failed(attr) ||
1489		     attr[IPSET_ATTR_SETNAME] == NULL))
1490		return -IPSET_ERR_PROTOCOL;
1491
1492	set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1493	if (set == NULL)
1494		return -ENOENT;
1495
1496	skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1497	if (skb2 == NULL)
1498		return -ENOMEM;
1499
1500	nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1501			 IPSET_CMD_HEADER);
1502	if (!nlh2)
1503		goto nlmsg_failure;
1504	if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL) ||
1505	    nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name) ||
1506	    nla_put_string(skb2, IPSET_ATTR_TYPENAME, set->type->name) ||
1507	    nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1508	    nla_put_u8(skb2, IPSET_ATTR_REVISION, set->revision))
1509		goto nla_put_failure;
1510	nlmsg_end(skb2, nlh2);
1511
1512	ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1513	if (ret < 0)
1514		return ret;
1515
1516	return 0;
1517
1518nla_put_failure:
1519	nlmsg_cancel(skb2, nlh2);
1520nlmsg_failure:
1521	kfree_skb(skb2);
1522	return -EMSGSIZE;
1523}
1524
1525/* Get type data */
1526
1527static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1528	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
1529	[IPSET_ATTR_TYPENAME]	= { .type = NLA_NUL_STRING,
1530				    .len = IPSET_MAXNAMELEN - 1 },
1531	[IPSET_ATTR_FAMILY]	= { .type = NLA_U8 },
1532};
1533
1534static int
1535ip_set_type(struct sock *ctnl, struct sk_buff *skb,
1536	    const struct nlmsghdr *nlh,
1537	    const struct nlattr * const attr[])
1538{
1539	struct sk_buff *skb2;
1540	struct nlmsghdr *nlh2;
1541	u8 family, min, max;
1542	const char *typename;
1543	int ret = 0;
1544
1545	if (unlikely(protocol_failed(attr) ||
1546		     attr[IPSET_ATTR_TYPENAME] == NULL ||
1547		     attr[IPSET_ATTR_FAMILY] == NULL))
1548		return -IPSET_ERR_PROTOCOL;
1549
1550	family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1551	typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1552	ret = find_set_type_minmax(typename, family, &min, &max);
1553	if (ret)
1554		return ret;
1555
1556	skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1557	if (skb2 == NULL)
1558		return -ENOMEM;
1559
1560	nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1561			 IPSET_CMD_TYPE);
1562	if (!nlh2)
1563		goto nlmsg_failure;
1564	if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL) ||
1565	    nla_put_string(skb2, IPSET_ATTR_TYPENAME, typename) ||
1566	    nla_put_u8(skb2, IPSET_ATTR_FAMILY, family) ||
1567	    nla_put_u8(skb2, IPSET_ATTR_REVISION, max) ||
1568	    nla_put_u8(skb2, IPSET_ATTR_REVISION_MIN, min))
1569		goto nla_put_failure;
1570	nlmsg_end(skb2, nlh2);
1571
1572	pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1573	ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1574	if (ret < 0)
1575		return ret;
1576
1577	return 0;
1578
1579nla_put_failure:
1580	nlmsg_cancel(skb2, nlh2);
1581nlmsg_failure:
1582	kfree_skb(skb2);
1583	return -EMSGSIZE;
1584}
1585
1586/* Get protocol version */
1587
1588static const struct nla_policy
1589ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1590	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
1591};
1592
1593static int
1594ip_set_protocol(struct sock *ctnl, struct sk_buff *skb,
1595		const struct nlmsghdr *nlh,
1596		const struct nlattr * const attr[])
1597{
1598	struct sk_buff *skb2;
1599	struct nlmsghdr *nlh2;
1600	int ret = 0;
1601
1602	if (unlikely(attr[IPSET_ATTR_PROTOCOL] == NULL))
1603		return -IPSET_ERR_PROTOCOL;
1604
1605	skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1606	if (skb2 == NULL)
1607		return -ENOMEM;
1608
1609	nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1610			 IPSET_CMD_PROTOCOL);
1611	if (!nlh2)
1612		goto nlmsg_failure;
1613	if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL))
1614		goto nla_put_failure;
1615	nlmsg_end(skb2, nlh2);
1616
1617	ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1618	if (ret < 0)
1619		return ret;
1620
1621	return 0;
1622
1623nla_put_failure:
1624	nlmsg_cancel(skb2, nlh2);
1625nlmsg_failure:
1626	kfree_skb(skb2);
1627	return -EMSGSIZE;
1628}
1629
1630static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1631	[IPSET_CMD_NONE]	= {
1632		.call		= ip_set_none,
1633		.attr_count	= IPSET_ATTR_CMD_MAX,
1634	},
1635	[IPSET_CMD_CREATE]	= {
1636		.call		= ip_set_create,
1637		.attr_count	= IPSET_ATTR_CMD_MAX,
1638		.policy		= ip_set_create_policy,
1639	},
1640	[IPSET_CMD_DESTROY]	= {
1641		.call		= ip_set_destroy,
1642		.attr_count	= IPSET_ATTR_CMD_MAX,
1643		.policy		= ip_set_setname_policy,
1644	},
1645	[IPSET_CMD_FLUSH]	= {
1646		.call		= ip_set_flush,
1647		.attr_count	= IPSET_ATTR_CMD_MAX,
1648		.policy		= ip_set_setname_policy,
1649	},
1650	[IPSET_CMD_RENAME]	= {
1651		.call		= ip_set_rename,
1652		.attr_count	= IPSET_ATTR_CMD_MAX,
1653		.policy		= ip_set_setname2_policy,
1654	},
1655	[IPSET_CMD_SWAP]	= {
1656		.call		= ip_set_swap,
1657		.attr_count	= IPSET_ATTR_CMD_MAX,
1658		.policy		= ip_set_setname2_policy,
1659	},
1660	[IPSET_CMD_LIST]	= {
1661		.call		= ip_set_dump,
1662		.attr_count	= IPSET_ATTR_CMD_MAX,
1663		.policy		= ip_set_setname_policy,
1664	},
1665	[IPSET_CMD_SAVE]	= {
1666		.call		= ip_set_dump,
1667		.attr_count	= IPSET_ATTR_CMD_MAX,
1668		.policy		= ip_set_setname_policy,
1669	},
1670	[IPSET_CMD_ADD]	= {
1671		.call		= ip_set_uadd,
1672		.attr_count	= IPSET_ATTR_CMD_MAX,
1673		.policy		= ip_set_adt_policy,
1674	},
1675	[IPSET_CMD_DEL]	= {
1676		.call		= ip_set_udel,
1677		.attr_count	= IPSET_ATTR_CMD_MAX,
1678		.policy		= ip_set_adt_policy,
1679	},
1680	[IPSET_CMD_TEST]	= {
1681		.call		= ip_set_utest,
1682		.attr_count	= IPSET_ATTR_CMD_MAX,
1683		.policy		= ip_set_adt_policy,
1684	},
1685	[IPSET_CMD_HEADER]	= {
1686		.call		= ip_set_header,
1687		.attr_count	= IPSET_ATTR_CMD_MAX,
1688		.policy		= ip_set_setname_policy,
1689	},
1690	[IPSET_CMD_TYPE]	= {
1691		.call		= ip_set_type,
1692		.attr_count	= IPSET_ATTR_CMD_MAX,
1693		.policy		= ip_set_type_policy,
1694	},
1695	[IPSET_CMD_PROTOCOL]	= {
1696		.call		= ip_set_protocol,
1697		.attr_count	= IPSET_ATTR_CMD_MAX,
1698		.policy		= ip_set_protocol_policy,
1699	},
1700};
1701
1702static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
1703	.name		= "ip_set",
1704	.subsys_id	= NFNL_SUBSYS_IPSET,
1705	.cb_count	= IPSET_MSG_MAX,
1706	.cb		= ip_set_netlink_subsys_cb,
1707};
1708
1709/* Interface to iptables/ip6tables */
1710
1711static int
1712ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
1713{
1714	unsigned int *op;
1715	void *data;
1716	int copylen = *len, ret = 0;
1717
1718	if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN))
1719		return -EPERM;
1720	if (optval != SO_IP_SET)
1721		return -EBADF;
1722	if (*len < sizeof(unsigned int))
1723		return -EINVAL;
1724
1725	data = vmalloc(*len);
1726	if (!data)
1727		return -ENOMEM;
1728	if (copy_from_user(data, user, *len) != 0) {
1729		ret = -EFAULT;
1730		goto done;
1731	}
1732	op = (unsigned int *) data;
1733
1734	if (*op < IP_SET_OP_VERSION) {
1735		/* Check the version at the beginning of operations */
1736		struct ip_set_req_version *req_version = data;
1737		if (req_version->version != IPSET_PROTOCOL) {
1738			ret = -EPROTO;
1739			goto done;
1740		}
1741	}
1742
1743	switch (*op) {
1744	case IP_SET_OP_VERSION: {
1745		struct ip_set_req_version *req_version = data;
1746
1747		if (*len != sizeof(struct ip_set_req_version)) {
1748			ret = -EINVAL;
1749			goto done;
1750		}
1751
1752		req_version->version = IPSET_PROTOCOL;
1753		ret = copy_to_user(user, req_version,
1754				   sizeof(struct ip_set_req_version));
1755		goto done;
1756	}
1757	case IP_SET_OP_GET_BYNAME: {
1758		struct ip_set_req_get_set *req_get = data;
1759		ip_set_id_t id;
1760
1761		if (*len != sizeof(struct ip_set_req_get_set)) {
1762			ret = -EINVAL;
1763			goto done;
1764		}
1765		req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
1766		nfnl_lock();
1767		find_set_and_id(req_get->set.name, &id);
1768		req_get->set.index = id;
1769		nfnl_unlock();
1770		goto copy;
1771	}
1772	case IP_SET_OP_GET_BYINDEX: {
1773		struct ip_set_req_get_set *req_get = data;
1774		struct ip_set *set;
1775
1776		if (*len != sizeof(struct ip_set_req_get_set) ||
1777		    req_get->set.index >= ip_set_max) {
1778			ret = -EINVAL;
1779			goto done;
1780		}
1781		nfnl_lock();
1782		set = nfnl_set(req_get->set.index);
1783		strncpy(req_get->set.name, set ? set->name : "",
1784			IPSET_MAXNAMELEN);
1785		nfnl_unlock();
1786		goto copy;
1787	}
1788	default:
1789		ret = -EBADMSG;
1790		goto done;
1791	}	/* end of switch(op) */
1792
1793copy:
1794	ret = copy_to_user(user, data, copylen);
1795
1796done:
1797	vfree(data);
1798	if (ret > 0)
1799		ret = 0;
1800	return ret;
1801}
1802
1803static struct nf_sockopt_ops so_set __read_mostly = {
1804	.pf		= PF_INET,
1805	.get_optmin	= SO_IP_SET,
1806	.get_optmax	= SO_IP_SET + 1,
1807	.get		= &ip_set_sockfn_get,
1808	.owner		= THIS_MODULE,
1809};
1810
1811static int __init
1812ip_set_init(void)
1813{
1814	struct ip_set **list;
1815	int ret;
1816
1817	if (max_sets)
1818		ip_set_max = max_sets;
1819	if (ip_set_max >= IPSET_INVALID_ID)
1820		ip_set_max = IPSET_INVALID_ID - 1;
1821
1822	list = kzalloc(sizeof(struct ip_set *) * ip_set_max, GFP_KERNEL);
1823	if (!list)
1824		return -ENOMEM;
1825
1826	rcu_assign_pointer(ip_set_list, list);
1827	ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
1828	if (ret != 0) {
1829		pr_err("ip_set: cannot register with nfnetlink.\n");
1830		kfree(list);
1831		return ret;
1832	}
1833	ret = nf_register_sockopt(&so_set);
1834	if (ret != 0) {
1835		pr_err("SO_SET registry failed: %d\n", ret);
1836		nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1837		kfree(list);
1838		return ret;
1839	}
1840
1841	pr_notice("ip_set: protocol %u\n", IPSET_PROTOCOL);
1842	return 0;
1843}
1844
1845static void __exit
1846ip_set_fini(void)
1847{
1848	struct ip_set **list = rcu_dereference_protected(ip_set_list, 1);
1849
1850	/* There can't be any existing set */
1851	nf_unregister_sockopt(&so_set);
1852	nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1853	kfree(list);
1854	pr_debug("these are the famous last words\n");
1855}
1856
1857module_init(ip_set_init);
1858module_exit(ip_set_fini);
1859