/* Copyright (C) 2013 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#ifndef _IP_SET_HASH_GEN_H
#define _IP_SET_HASH_GEN_H

#include <linux/rcupdate.h>
#include <linux/jhash.h>
#include <linux/netfilter/ipset/ip_set_timeout.h>
#ifndef rcu_dereference_bh
#define rcu_dereference_bh(p)	rcu_dereference(p)
#endif

#define CONCAT(a, b)		a##b
#define TOKEN(a, b)		CONCAT(a, b)

/* Hashing which uses arrays to resolve clashing. The hash table is resized
 * (doubled) when searching becomes too long.
 * Internally jhash is used with the assumption that the size of the
 * stored data is a multiple of sizeof(u32). If storage supports timeout,
 * the timeout field must be the last one in the data structure - that field
 * is ignored when computing the hash key.
 *
 * Readers and resizing
 *
 * Resizing can be triggered by userspace command only, and those
 * are serialized by the nfnl mutex. During resizing the set is
 * read-locked, so the only possible concurrent operations are
 * the kernel side readers. Those must be protected by proper RCU locking.
 */
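
/* In practice kernel-side readers fetch the table with rcu_dereference_bh()
 * under rcu_read_lock_bh(), while mtype_resize() publishes the new table
 * with rcu_assign_pointer() and waits with synchronize_rcu_bh() before
 * freeing the old one (see below).
 */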

/* Number of elements to store in an initial array block */
#define AHASH_INIT_SIZE			4
/* Max number of elements to store in an array block */
#define AHASH_MAX_SIZE			(3*AHASH_INIT_SIZE)

/* Max number of elements can be tuned */
#ifdef IP_SET_HASH_WITH_MULTI
#define AHASH_MAX(h)			((h)->ahash_max)

static inline u8
tune_ahash_max(u8 curr, u32 multi)
{
	u32 n;

	if (multi < curr)
		return curr;

	n = curr + AHASH_INIT_SIZE;
	/* Currently, when listing the set one hash bucket must fit into
	 * a single netlink message. Therefore we have a hard limit here.
	 */
	return n > curr && n <= 64 ? n : curr;
}
#define TUNE_AHASH_MAX(h, multi)	\
	((h)->ahash_max = tune_ahash_max((h)->ahash_max, multi))
#else
#define AHASH_MAX(h)			AHASH_MAX_SIZE
#define TUNE_AHASH_MAX(h, multi)
#endif
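
/* With the defaults above a bucket array starts with 4 slots and grows in
 * steps of AHASH_INIT_SIZE up to 12 entries; only with
 * IP_SET_HASH_WITH_MULTI may the per-set limit be raised further, step by
 * step, up to 64 entries via TUNE_AHASH_MAX().
 */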

/* A hash bucket */
struct hbucket {
	void *value;		/* the array of the values */
	u8 size;		/* size of the array */
	u8 pos;			/* position of the first free entry */
};

/* The hash table: the table size is stored here to make resizing easy */
struct htable {
	u8 htable_bits;		/* size of hash table == 2^htable_bits */
	struct hbucket bucket[0]; /* hashtable buckets */
};

#define hbucket(h, i)		(&((h)->bucket[i]))

/* Book-keeping of the prefixes added to the set */
struct net_prefixes {
	u8 cidr;		/* the different cidr values in the set */
	u32 nets;		/* number of elements per cidr */
};

/* Compute the hash table size */
static size_t
htable_size(u8 hbits)
{
	size_t hsize;

	/* We must fit both into u32 in jhash and size_t */
	if (hbits > 31)
		return 0;
	hsize = jhash_size(hbits);
	if ((((size_t)-1) - sizeof(struct htable))/sizeof(struct hbucket)
	    < hsize)
		return 0;

	return hsize * sizeof(struct hbucket) + sizeof(struct htable);
}
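
/* For example hbits == 10 means jhash_size(10) == 1024 buckets, so the
 * table needs 1024 * sizeof(struct hbucket) + sizeof(struct htable) bytes;
 * zero is returned when the requested size would overflow u32 or size_t.
 */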

/* Compute htable_bits from the user input parameter hashsize */
static u8
htable_bits(u32 hashsize)
{
	/* Assume that hashsize == 2^htable_bits */
	u8 bits = fls(hashsize - 1);

	if (jhash_size(bits) != hashsize)
		/* Round up to the first 2^n value */
		bits = fls(hashsize);

	return bits;
}
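
/* For example hashsize 1024 is already a power of two and maps to 10 bits,
 * while hashsize 1500 is rounded up to 2048 buckets, i.e. 11 bits.
 */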

/* Destroy the hashtable part of the set */
static void
ahash_destroy(struct htable *t)
{
	struct hbucket *n;
	u32 i;

	for (i = 0; i < jhash_size(t->htable_bits); i++) {
		n = hbucket(t, i);
		if (n->size)
			/* FIXME: use slab cache */
			kfree(n->value);
	}

	ip_set_free(t);
}

static int
hbucket_elem_add(struct hbucket *n, u8 ahash_max, size_t dsize)
{
	if (n->pos >= n->size) {
		void *tmp;

		if (n->size >= ahash_max)
			/* Trigger rehashing */
			return -EAGAIN;

		tmp = kzalloc((n->size + AHASH_INIT_SIZE) * dsize,
			      GFP_ATOMIC);
		if (!tmp)
			return -ENOMEM;
		if (n->size) {
			memcpy(tmp, n->value, n->size * dsize);
			kfree(n->value);
		}
		n->value = tmp;
		n->size += AHASH_INIT_SIZE;
	}
	return 0;
}
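
/* hbucket_elem_add() only makes room in the bucket: it returns 0 when a
 * free slot exists (growing the array by AHASH_INIT_SIZE entries if
 * necessary), -EAGAIN when the bucket is already at ahash_max entries so
 * the caller has to rehash, and -ENOMEM on allocation failure. The caller
 * fills the slot at n->pos and increments pos itself.
 */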

#ifdef IP_SET_HASH_WITH_NETS
#ifdef IP_SET_HASH_WITH_NETS_PACKED
/* When cidr is packed with nomatch, cidr - 1 is stored in the entry */
#define CIDR(cidr)		(cidr + 1)
#else
#define CIDR(cidr)		(cidr)
#endif

#define SET_HOST_MASK(family)	(family == AF_INET ? 32 : 128)

#ifdef IP_SET_HASH_WITH_MULTI
#define NETS_LENGTH(family)	(SET_HOST_MASK(family) + 1)
#else
#define NETS_LENGTH(family)	SET_HOST_MASK(family)
#endif

#else
#define NETS_LENGTH(family)	0
#endif /* IP_SET_HASH_WITH_NETS */
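
/* NETS_LENGTH() is the number of distinct prefix lengths the set may have
 * to book-keep: 32 for IPv4 and 128 for IPv6 (one more when
 * IP_SET_HASH_WITH_MULTI is defined), and 0 when the set type does not
 * store different sized networks at all.
 */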

#define ext_timeout(e, h)	\
(unsigned long *)(((void *)(e)) + (h)->offset[IPSET_OFFSET_TIMEOUT])
#define ext_counter(e, h)	\
(struct ip_set_counter *)(((void *)(e)) + (h)->offset[IPSET_OFFSET_COUNTER])

#endif /* _IP_SET_HASH_GEN_H */

/* Family dependent templates */

#undef ahash_data
#undef mtype_data_equal
#undef mtype_do_data_match
#undef mtype_data_set_flags
#undef mtype_data_reset_flags
#undef mtype_data_netmask
#undef mtype_data_list
#undef mtype_data_next
#undef mtype_elem

#undef mtype_add_cidr
#undef mtype_del_cidr
#undef mtype_ahash_memsize
#undef mtype_flush
#undef mtype_destroy
#undef mtype_gc_init
#undef mtype_same_set
#undef mtype_kadt
#undef mtype_uadt
#undef mtype

#undef mtype_add
#undef mtype_del
#undef mtype_test_cidrs
#undef mtype_test
#undef mtype_expire
#undef mtype_resize
#undef mtype_head
#undef mtype_list
#undef mtype_gc
#undef mtype_gc_init
#undef mtype_variant
#undef mtype_data_match

#undef HKEY

#define mtype_data_equal	TOKEN(MTYPE, _data_equal)
#ifdef IP_SET_HASH_WITH_NETS
#define mtype_do_data_match	TOKEN(MTYPE, _do_data_match)
#else
#define mtype_do_data_match(d)	1
#endif
#define mtype_data_set_flags	TOKEN(MTYPE, _data_set_flags)
#define mtype_data_reset_flags	TOKEN(MTYPE, _data_reset_flags)
#define mtype_data_netmask	TOKEN(MTYPE, _data_netmask)
#define mtype_data_list		TOKEN(MTYPE, _data_list)
#define mtype_data_next		TOKEN(MTYPE, _data_next)
#define mtype_elem		TOKEN(MTYPE, _elem)
#define mtype_add_cidr		TOKEN(MTYPE, _add_cidr)
#define mtype_del_cidr		TOKEN(MTYPE, _del_cidr)
#define mtype_ahash_memsize	TOKEN(MTYPE, _ahash_memsize)
#define mtype_flush		TOKEN(MTYPE, _flush)
#define mtype_destroy		TOKEN(MTYPE, _destroy)
#define mtype_gc_init		TOKEN(MTYPE, _gc_init)
#define mtype_same_set		TOKEN(MTYPE, _same_set)
#define mtype_kadt		TOKEN(MTYPE, _kadt)
#define mtype_uadt		TOKEN(MTYPE, _uadt)
#define mtype			MTYPE

#define mtype_elem		TOKEN(MTYPE, _elem)
#define mtype_add		TOKEN(MTYPE, _add)
#define mtype_del		TOKEN(MTYPE, _del)
#define mtype_test_cidrs	TOKEN(MTYPE, _test_cidrs)
#define mtype_test		TOKEN(MTYPE, _test)
#define mtype_expire		TOKEN(MTYPE, _expire)
#define mtype_resize		TOKEN(MTYPE, _resize)
#define mtype_head		TOKEN(MTYPE, _head)
#define mtype_list		TOKEN(MTYPE, _list)
#define mtype_gc		TOKEN(MTYPE, _gc)
#define mtype_variant		TOKEN(MTYPE, _variant)
#define mtype_data_match	TOKEN(MTYPE, _data_match)

#ifndef HKEY_DATALEN
#define HKEY_DATALEN		sizeof(struct mtype_elem)
#endif

#define HKEY(data, initval, htable_bits)			\
(jhash2((u32 *)(data), HKEY_DATALEN/sizeof(u32), initval)	\
	& jhash_mask(htable_bits))
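
/* HKEY() runs jhash2() over the first HKEY_DATALEN/sizeof(u32) words of
 * the element and masks the result to the current table size; extension
 * fields such as the timeout live past the hashed part, so they never
 * influence the bucket an element lands in.
 */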

#ifndef htype
#define htype			HTYPE

/* The generic hash structure */
struct htype {
	struct htable *table;	/* the hash table */
	u32 maxelem;		/* max elements in the hash */
	u32 elements;		/* current number of elements (vs timeout) */
	u32 initval;		/* random jhash init value */
	u32 timeout;		/* timeout value, if enabled */
	size_t dsize;		/* data struct size */
	size_t offset[IPSET_OFFSET_MAX]; /* Offsets to extensions */
	struct timer_list gc;	/* garbage collection when timeout enabled */
	struct mtype_elem next; /* temporary storage for uadd */
#ifdef IP_SET_HASH_WITH_MULTI
	u8 ahash_max;		/* max elements in an array block */
#endif
#ifdef IP_SET_HASH_WITH_NETMASK
	u8 netmask;		/* netmask value for subnets to store */
#endif
#ifdef IP_SET_HASH_WITH_RBTREE
	struct rb_root rbtree;
#endif
#ifdef IP_SET_HASH_WITH_NETS
	struct net_prefixes nets[0]; /* book-keeping of prefixes */
#endif
};
#endif

#ifdef IP_SET_HASH_WITH_NETS
/* Network cidr size book-keeping when the hash stores different-sized
 * networks */
static void
mtype_add_cidr(struct htype *h, u8 cidr, u8 nets_length)
{
	int i, j;

	/* Add in increasing prefix order, so larger cidr first */
	for (i = 0, j = -1; i < nets_length && h->nets[i].nets; i++) {
		if (j != -1)
			continue;
		else if (h->nets[i].cidr < cidr)
			j = i;
		else if (h->nets[i].cidr == cidr) {
			h->nets[i].nets++;
			return;
		}
	}
	if (j != -1) {
		for (; i > j; i--) {
			h->nets[i].cidr = h->nets[i - 1].cidr;
			h->nets[i].nets = h->nets[i - 1].nets;
		}
	}
	h->nets[i].cidr = cidr;
	h->nets[i].nets = 1;
}
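
/* The nets[] array is kept sorted by decreasing cidr value (most specific
 * prefix first), each slot counting the elements stored with that prefix
 * length. E.g. adding a /24, a /16 and another /24 element gives
 * nets[0] = { .cidr = 24, .nets = 2 } and nets[1] = { .cidr = 16, .nets = 1 }.
 */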

static void
mtype_del_cidr(struct htype *h, u8 cidr, u8 nets_length)
{
	u8 i, j;

	for (i = 0; i < nets_length - 1 && h->nets[i].cidr != cidr; i++)
		;
	h->nets[i].nets--;

	if (h->nets[i].nets != 0)
		return;

	for (j = i; j < nets_length - 1 && h->nets[j].nets; j++) {
		h->nets[j].cidr = h->nets[j + 1].cidr;
		h->nets[j].nets = h->nets[j + 1].nets;
	}
}
#endif

/* Calculate the actual memory size of the set data */
static size_t
mtype_ahash_memsize(const struct htype *h, u8 nets_length)
{
	u32 i;
	struct htable *t = h->table;
	size_t memsize = sizeof(*h)
			 + sizeof(*t)
#ifdef IP_SET_HASH_WITH_NETS
			 + sizeof(struct net_prefixes) * nets_length
#endif
			 + jhash_size(t->htable_bits) * sizeof(struct hbucket);

	for (i = 0; i < jhash_size(t->htable_bits); i++)
		memsize += t->bucket[i].size * h->dsize;

	return memsize;
}

/* Flush a hash type of set: destroy all elements */
static void
mtype_flush(struct ip_set *set)
{
	struct htype *h = set->data;
	struct htable *t = h->table;
	struct hbucket *n;
	u32 i;

	for (i = 0; i < jhash_size(t->htable_bits); i++) {
		n = hbucket(t, i);
		if (n->size) {
			n->size = n->pos = 0;
			/* FIXME: use slab cache */
			kfree(n->value);
		}
	}
#ifdef IP_SET_HASH_WITH_NETS
	memset(h->nets, 0, sizeof(struct net_prefixes)
			   * NETS_LENGTH(set->family));
#endif
	h->elements = 0;
}

/* Destroy a hash type of set */
static void
mtype_destroy(struct ip_set *set)
{
	struct htype *h = set->data;

	if (set->extensions & IPSET_EXT_TIMEOUT)
		del_timer_sync(&h->gc);

	ahash_destroy(h->table);
#ifdef IP_SET_HASH_WITH_RBTREE
	rbtree_destroy(&h->rbtree);
#endif
	kfree(h);

	set->data = NULL;
}

static void
mtype_gc_init(struct ip_set *set, void (*gc)(unsigned long ul_set))
{
	struct htype *h = set->data;

	init_timer(&h->gc);
	h->gc.data = (unsigned long) set;
	h->gc.function = gc;
	h->gc.expires = jiffies + IPSET_GC_PERIOD(h->timeout) * HZ;
	add_timer(&h->gc);
	pr_debug("gc initialized, run in every %u\n",
		 IPSET_GC_PERIOD(h->timeout));
}

static bool
mtype_same_set(const struct ip_set *a, const struct ip_set *b)
{
	const struct htype *x = a->data;
	const struct htype *y = b->data;

	/* Resizing changes htable_bits, so we ignore it */
	return x->maxelem == y->maxelem &&
	       x->timeout == y->timeout &&
#ifdef IP_SET_HASH_WITH_NETMASK
	       x->netmask == y->netmask &&
#endif
	       a->extensions == b->extensions;
}

/* Get the ith element from the array block n */
#define ahash_data(n, i, dsize)	\
	((struct mtype_elem *)((n)->value + ((i) * (dsize))))

/* Delete expired elements from the hashtable */
static void
mtype_expire(struct htype *h, u8 nets_length, size_t dsize)
{
	struct htable *t = h->table;
	struct hbucket *n;
	struct mtype_elem *data;
	u32 i;
	int j;

	for (i = 0; i < jhash_size(t->htable_bits); i++) {
		n = hbucket(t, i);
		for (j = 0; j < n->pos; j++) {
			data = ahash_data(n, j, dsize);
			if (ip_set_timeout_expired(ext_timeout(data, h))) {
				pr_debug("expired %u/%u\n", i, j);
#ifdef IP_SET_HASH_WITH_NETS
				mtype_del_cidr(h, CIDR(data->cidr),
					       nets_length);
#endif
				if (j != n->pos - 1)
					/* Not last one */
					memcpy(data,
					       ahash_data(n, n->pos - 1, dsize),
					       dsize);
				n->pos--;
				h->elements--;
			}
		}
		if (n->pos + AHASH_INIT_SIZE < n->size) {
			void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
					    * dsize,
					    GFP_ATOMIC);
			if (!tmp)
				/* Still try to delete expired elements */
				continue;
			n->size -= AHASH_INIT_SIZE;
			memcpy(tmp, n->value, n->size * dsize);
			kfree(n->value);
			n->value = tmp;
		}
	}
}
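
/* An expired entry is removed by overwriting it with the last entry of the
 * bucket; once at least AHASH_INIT_SIZE slots are unused the array is
 * shrunk by one step to give memory back.
 */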

static void
mtype_gc(unsigned long ul_set)
{
	struct ip_set *set = (struct ip_set *) ul_set;
	struct htype *h = set->data;

	pr_debug("called\n");
	write_lock_bh(&set->lock);
	mtype_expire(h, NETS_LENGTH(set->family), h->dsize);
	write_unlock_bh(&set->lock);

	h->gc.expires = jiffies + IPSET_GC_PERIOD(h->timeout) * HZ;
	add_timer(&h->gc);
}

/* Resize a hash: create a new hash table with double the hashsize and
 * insert the elements into it. Repeat until we succeed or fail due to
 * memory pressure.
 */
static int
mtype_resize(struct ip_set *set, bool retried)
{
	struct htype *h = set->data;
	struct htable *t, *orig = h->table;
	u8 htable_bits = orig->htable_bits;
#ifdef IP_SET_HASH_WITH_NETS
	u8 flags;
#endif
	struct mtype_elem *data;
	struct mtype_elem *d;
	struct hbucket *n, *m;
	u32 i, j;
	int ret;

	/* Try to cleanup once */
	if (SET_WITH_TIMEOUT(set) && !retried) {
		i = h->elements;
		write_lock_bh(&set->lock);
		mtype_expire(set->data, NETS_LENGTH(set->family),
			     h->dsize);
		write_unlock_bh(&set->lock);
		if (h->elements < i)
			return 0;
	}

retry:
	ret = 0;
	htable_bits++;
	pr_debug("attempt to resize set %s from %u to %u, t %p\n",
		 set->name, orig->htable_bits, htable_bits, orig);
	if (!htable_bits) {
		/* In case we have plenty of memory :-) */
		pr_warning("Cannot increase the hashsize of set %s further\n",
			   set->name);
		return -IPSET_ERR_HASH_FULL;
	}
	t = ip_set_alloc(sizeof(*t)
			 + jhash_size(htable_bits) * sizeof(struct hbucket));
	if (!t)
		return -ENOMEM;
	t->htable_bits = htable_bits;

	read_lock_bh(&set->lock);
	for (i = 0; i < jhash_size(orig->htable_bits); i++) {
		n = hbucket(orig, i);
		for (j = 0; j < n->pos; j++) {
			data = ahash_data(n, j, h->dsize);
#ifdef IP_SET_HASH_WITH_NETS
			flags = 0;
			mtype_data_reset_flags(data, &flags);
#endif
			m = hbucket(t, HKEY(data, h->initval, htable_bits));
			ret = hbucket_elem_add(m, AHASH_MAX(h), h->dsize);
			if (ret < 0) {
#ifdef IP_SET_HASH_WITH_NETS
				mtype_data_reset_flags(data, &flags);
#endif
				read_unlock_bh(&set->lock);
				ahash_destroy(t);
				if (ret == -EAGAIN)
					goto retry;
				return ret;
			}
			d = ahash_data(m, m->pos++, h->dsize);
			memcpy(d, data, h->dsize);
#ifdef IP_SET_HASH_WITH_NETS
			mtype_data_reset_flags(d, &flags);
#endif
		}
	}

	rcu_assign_pointer(h->table, t);
	read_unlock_bh(&set->lock);

	/* Give time to other readers of the set */
	synchronize_rcu_bh();

	pr_debug("set %s resized from %u (%p) to %u (%p)\n", set->name,
		 orig->htable_bits, orig, t->htable_bits, t);
	ahash_destroy(orig);

	return 0;
}
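
/* The resize walks the old table under the read-locked set, rehashes every
 * element into the new, twice as large table, publishes it with
 * rcu_assign_pointer() and waits out a grace period with
 * synchronize_rcu_bh() before destroying the old table, so kernel-side
 * readers always see a complete table.
 */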

/* Add an element to a hash and update the internal counters when it
 * succeeds, otherwise report the proper error code.
 */
static int
mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
	  struct ip_set_ext *mext, u32 flags)
{
	struct htype *h = set->data;
	struct htable *t;
	const struct mtype_elem *d = value;
	struct mtype_elem *data;
	struct hbucket *n;
	int i, ret = 0;
	int j = AHASH_MAX(h) + 1;
	bool flag_exist = flags & IPSET_FLAG_EXIST;
	u32 key, multi = 0;

	if (SET_WITH_TIMEOUT(set) && h->elements >= h->maxelem)
		/* FIXME: when set is full, we slow down here */
		mtype_expire(h, NETS_LENGTH(set->family), h->dsize);

	if (h->elements >= h->maxelem) {
		if (net_ratelimit())
			pr_warning("Set %s is full, maxelem %u reached\n",
				   set->name, h->maxelem);
		return -IPSET_ERR_HASH_FULL;
	}

	rcu_read_lock_bh();
	t = rcu_dereference_bh(h->table);
	key = HKEY(value, h->initval, t->htable_bits);
	n = hbucket(t, key);
	for (i = 0; i < n->pos; i++) {
		data = ahash_data(n, i, h->dsize);
		if (mtype_data_equal(data, d, &multi)) {
			if (flag_exist ||
			    (SET_WITH_TIMEOUT(set) &&
			     ip_set_timeout_expired(ext_timeout(data, h)))) {
				/* Just the extensions could be overwritten */
				j = i;
				goto reuse_slot;
			} else {
				ret = -IPSET_ERR_EXIST;
				goto out;
			}
		}
		/* Remember the first timed out entry so its slot can be
		 * reused; j still holds the AHASH_MAX(h) + 1 sentinel if
		 * no such entry has been seen yet.
		 */
		if (SET_WITH_TIMEOUT(set) &&
		    ip_set_timeout_expired(ext_timeout(data, h)) &&
		    j == AHASH_MAX(h) + 1)
			j = i;
	}
reuse_slot:
	if (j != AHASH_MAX(h) + 1) {
		/* Fill out reused slot */
		data = ahash_data(n, j, h->dsize);
#ifdef IP_SET_HASH_WITH_NETS
		mtype_del_cidr(h, CIDR(data->cidr), NETS_LENGTH(set->family));
		mtype_add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
#endif
	} else {
		/* Use/create a new slot */
		TUNE_AHASH_MAX(h, multi);
		ret = hbucket_elem_add(n, AHASH_MAX(h), h->dsize);
		if (ret != 0) {
			if (ret == -EAGAIN)
				mtype_data_next(&h->next, d);
			goto out;
		}
		data = ahash_data(n, n->pos++, h->dsize);
#ifdef IP_SET_HASH_WITH_NETS
		mtype_add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
#endif
		h->elements++;
	}
	memcpy(data, d, sizeof(struct mtype_elem));
#ifdef IP_SET_HASH_WITH_NETS
	mtype_data_set_flags(data, flags);
#endif
	if (SET_WITH_TIMEOUT(set))
		ip_set_timeout_set(ext_timeout(data, h), ext->timeout);
	if (SET_WITH_COUNTER(set))
		ip_set_init_counter(ext_counter(data, h), ext);

out:
	rcu_read_unlock_bh();
	return ret;
}
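
/* Note: when hbucket_elem_add() returns -EAGAIN the error is passed back
 * to the caller; the expectation (outside this file) is that the set is
 * resized and the operation retried, with h->next remembering where an
 * interrupted range addition should continue from.
 */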

/* Delete an element from the hash: swap it with the last element
 * and free up space if possible.
 */
static int
mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext,
	  struct ip_set_ext *mext, u32 flags)
{
	struct htype *h = set->data;
	struct htable *t = h->table;
	const struct mtype_elem *d = value;
	struct mtype_elem *data;
	struct hbucket *n;
	int i;
	u32 key, multi = 0;

	key = HKEY(value, h->initval, t->htable_bits);
	n = hbucket(t, key);
	for (i = 0; i < n->pos; i++) {
		data = ahash_data(n, i, h->dsize);
		if (!mtype_data_equal(data, d, &multi))
			continue;
		if (SET_WITH_TIMEOUT(set) &&
		    ip_set_timeout_expired(ext_timeout(data, h)))
			return -IPSET_ERR_EXIST;
		if (i != n->pos - 1)
			/* Not last one */
			memcpy(data, ahash_data(n, n->pos - 1, h->dsize),
			       h->dsize);

		n->pos--;
		h->elements--;
#ifdef IP_SET_HASH_WITH_NETS
		mtype_del_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
#endif
		if (n->pos + AHASH_INIT_SIZE < n->size) {
			void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
					    * h->dsize,
					    GFP_ATOMIC);
			if (!tmp)
				return 0;
			n->size -= AHASH_INIT_SIZE;
			memcpy(tmp, n->value, n->size * h->dsize);
			kfree(n->value);
			n->value = tmp;
		}
		return 0;
	}

	return -IPSET_ERR_EXIST;
}
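
/* As in mtype_expire(), deletion keeps the bucket compact by moving the
 * last entry into the freed slot and shrinking the array once enough slots
 * are unused; a failed shrink allocation is not an error, the element is
 * gone either way.
 */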

static inline int
mtype_data_match(struct mtype_elem *data, const struct ip_set_ext *ext,
		 struct ip_set_ext *mext, struct ip_set *set, u32 flags)
{
	if (SET_WITH_COUNTER(set))
		ip_set_update_counter(ext_counter(data,
						  (struct htype *)(set->data)),
				      ext, mext, flags);
	return mtype_do_data_match(data);
}

#ifdef IP_SET_HASH_WITH_NETS
/* Special test function which takes into account the different network
 * sizes added to the set */
static int
mtype_test_cidrs(struct ip_set *set, struct mtype_elem *d,
		 const struct ip_set_ext *ext,
		 struct ip_set_ext *mext, u32 flags)
{
	struct htype *h = set->data;
	struct htable *t = h->table;
	struct hbucket *n;
	struct mtype_elem *data;
	int i, j = 0;
	u32 key, multi = 0;
	u8 nets_length = NETS_LENGTH(set->family);

	pr_debug("test by nets\n");
	for (; j < nets_length && h->nets[j].nets && !multi; j++) {
		mtype_data_netmask(d, h->nets[j].cidr);
		key = HKEY(d, h->initval, t->htable_bits);
		n = hbucket(t, key);
		for (i = 0; i < n->pos; i++) {
			data = ahash_data(n, i, h->dsize);
			if (!mtype_data_equal(data, d, &multi))
				continue;
			if (SET_WITH_TIMEOUT(set)) {
				if (!ip_set_timeout_expired(
							ext_timeout(data, h)))
					return mtype_data_match(data, ext,
								mext, set,
								flags);
#ifdef IP_SET_HASH_WITH_MULTI
				multi = 0;
#endif
			} else
				return mtype_data_match(data, ext,
							mext, set, flags);
		}
	}
	return 0;
}
#endif

/* Test whether the element is added to the set */
static int
mtype_test(struct ip_set *set, void *value, const struct ip_set_ext *ext,
	   struct ip_set_ext *mext, u32 flags)
{
	struct htype *h = set->data;
	struct htable *t = h->table;
	struct mtype_elem *d = value;
	struct hbucket *n;
	struct mtype_elem *data;
	int i;
	u32 key, multi = 0;

#ifdef IP_SET_HASH_WITH_NETS
	/* If we test an IP address and not a network address,
	 * try all possible network sizes */
	if (CIDR(d->cidr) == SET_HOST_MASK(set->family))
		return mtype_test_cidrs(set, d, ext, mext, flags);
#endif

	key = HKEY(d, h->initval, t->htable_bits);
	n = hbucket(t, key);
	for (i = 0; i < n->pos; i++) {
		data = ahash_data(n, i, h->dsize);
		if (mtype_data_equal(data, d, &multi) &&
		    !(SET_WITH_TIMEOUT(set) &&
		      ip_set_timeout_expired(ext_timeout(data, h))))
			return mtype_data_match(data, ext, mext, set, flags);
	}
	return 0;
}

/* Reply a HEADER request: fill out the header part of the set */
static int
mtype_head(struct ip_set *set, struct sk_buff *skb)
{
	const struct htype *h = set->data;
	struct nlattr *nested;
	size_t memsize;

	read_lock_bh(&set->lock);
	memsize = mtype_ahash_memsize(h, NETS_LENGTH(set->family));
	read_unlock_bh(&set->lock);

	nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
	if (!nested)
		goto nla_put_failure;
	if (nla_put_net32(skb, IPSET_ATTR_HASHSIZE,
			  htonl(jhash_size(h->table->htable_bits))) ||
	    nla_put_net32(skb, IPSET_ATTR_MAXELEM, htonl(h->maxelem)))
		goto nla_put_failure;
#ifdef IP_SET_HASH_WITH_NETMASK
	if (h->netmask != HOST_MASK &&
	    nla_put_u8(skb, IPSET_ATTR_NETMASK, h->netmask))
		goto nla_put_failure;
#endif
	if (nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref - 1)) ||
	    nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)) ||
	    ((set->extensions & IPSET_EXT_TIMEOUT) &&
	     nla_put_net32(skb, IPSET_ATTR_TIMEOUT, htonl(h->timeout))) ||
	    ((set->extensions & IPSET_EXT_COUNTER) &&
	     nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS,
			   htonl(IPSET_FLAG_WITH_COUNTERS))))
		goto nla_put_failure;
	ipset_nest_end(skb, nested);

	return 0;
nla_put_failure:
	return -EMSGSIZE;
}

/* Reply a LIST/SAVE request: dump the elements of the specified set */
static int
mtype_list(const struct ip_set *set,
	   struct sk_buff *skb, struct netlink_callback *cb)
{
	const struct htype *h = set->data;
	const struct htable *t = h->table;
	struct nlattr *atd, *nested;
	const struct hbucket *n;
	const struct mtype_elem *e;
	u32 first = cb->args[2];
	/* We assume that one hash bucket fits into one page */
	void *incomplete;
	int i;

	atd = ipset_nest_start(skb, IPSET_ATTR_ADT);
	if (!atd)
		return -EMSGSIZE;
	pr_debug("list hash set %s\n", set->name);
	for (; cb->args[2] < jhash_size(t->htable_bits); cb->args[2]++) {
		incomplete = skb_tail_pointer(skb);
		n = hbucket(t, cb->args[2]);
		pr_debug("cb->args[2]: %lu, t %p n %p\n", cb->args[2], t, n);
		for (i = 0; i < n->pos; i++) {
			e = ahash_data(n, i, h->dsize);
			if (SET_WITH_TIMEOUT(set) &&
			    ip_set_timeout_expired(ext_timeout(e, h)))
				continue;
			pr_debug("list hash %lu hbucket %p i %u, data %p\n",
				 cb->args[2], n, i, e);
			nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
			if (!nested) {
				if (cb->args[2] == first) {
					nla_nest_cancel(skb, atd);
					return -EMSGSIZE;
				} else
					goto nla_put_failure;
			}
			if (mtype_data_list(skb, e))
				goto nla_put_failure;
			if (SET_WITH_TIMEOUT(set) &&
			    nla_put_net32(skb, IPSET_ATTR_TIMEOUT,
					  htonl(ip_set_timeout_get(
						ext_timeout(e, h)))))
				goto nla_put_failure;
			if (SET_WITH_COUNTER(set) &&
			    ip_set_put_counter(skb, ext_counter(e, h)))
				goto nla_put_failure;
			ipset_nest_end(skb, nested);
		}
	}
	ipset_nest_end(skb, atd);
	/* Set listing finished */
	cb->args[2] = 0;

	return 0;

nla_put_failure:
	nlmsg_trim(skb, incomplete);
	ipset_nest_end(skb, atd);
	if (unlikely(first == cb->args[2])) {
		pr_warning("Can't list set %s: one bucket does not fit into "
			   "a message. Please report it!\n", set->name);
		cb->args[2] = 0;
		return -EMSGSIZE;
	}
	return 0;
}
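
/* Dumping is resumable: cb->args[2] records the next bucket to list. When
 * a bucket does not fit into the current message, the partial attributes
 * are trimmed and the dump continues from that bucket in the next message;
 * only a single bucket too big for an empty message is a hard error.
 */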

static int
TOKEN(MTYPE, _kadt)(struct ip_set *set, const struct sk_buff *skb,
	      const struct xt_action_param *par,
	      enum ipset_adt adt, struct ip_set_adt_opt *opt);

static int
TOKEN(MTYPE, _uadt)(struct ip_set *set, struct nlattr *tb[],
	      enum ipset_adt adt, u32 *lineno, u32 flags, bool retried);

static const struct ip_set_type_variant mtype_variant = {
	.kadt	= mtype_kadt,
	.uadt	= mtype_uadt,
	.adt	= {
		[IPSET_ADD] = mtype_add,
		[IPSET_DEL] = mtype_del,
		[IPSET_TEST] = mtype_test,
	},
	.destroy = mtype_destroy,
	.flush	= mtype_flush,
	.head	= mtype_head,
	.list	= mtype_list,
	.resize	= mtype_resize,
	.same_set = mtype_same_set,
};

#ifdef IP_SET_EMIT_CREATE
static int
TOKEN(HTYPE, _create)(struct ip_set *set, struct nlattr *tb[], u32 flags)
{
	u32 hashsize = IPSET_DEFAULT_HASHSIZE, maxelem = IPSET_DEFAULT_MAXELEM;
	u32 cadt_flags = 0;
	u8 hbits;
#ifdef IP_SET_HASH_WITH_NETMASK
	u8 netmask;
#endif
	size_t hsize;
	struct HTYPE *h;

	if (!(set->family == NFPROTO_IPV4 || set->family == NFPROTO_IPV6))
		return -IPSET_ERR_INVALID_FAMILY;
#ifdef IP_SET_HASH_WITH_NETMASK
	netmask = set->family == NFPROTO_IPV4 ? 32 : 128;
	pr_debug("Create set %s with family %s\n",
		 set->name, set->family == NFPROTO_IPV4 ? "inet" : "inet6");
#endif

	if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_HASHSIZE) ||
		     !ip_set_optattr_netorder(tb, IPSET_ATTR_MAXELEM) ||
		     !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
		     !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS)))
		return -IPSET_ERR_PROTOCOL;

	if (tb[IPSET_ATTR_HASHSIZE]) {
		hashsize = ip_set_get_h32(tb[IPSET_ATTR_HASHSIZE]);
		if (hashsize < IPSET_MIMINAL_HASHSIZE)
			hashsize = IPSET_MIMINAL_HASHSIZE;
	}

	if (tb[IPSET_ATTR_MAXELEM])
		maxelem = ip_set_get_h32(tb[IPSET_ATTR_MAXELEM]);

#ifdef IP_SET_HASH_WITH_NETMASK
	if (tb[IPSET_ATTR_NETMASK]) {
		netmask = nla_get_u8(tb[IPSET_ATTR_NETMASK]);

		if ((set->family == NFPROTO_IPV4 && netmask > 32) ||
		    (set->family == NFPROTO_IPV6 && netmask > 128) ||
		    netmask == 0)
			return -IPSET_ERR_INVALID_NETMASK;
	}
#endif

	hsize = sizeof(*h);
#ifdef IP_SET_HASH_WITH_NETS
	hsize += sizeof(struct net_prefixes) *
		(set->family == NFPROTO_IPV4 ? 32 : 128);
#endif
	h = kzalloc(hsize, GFP_KERNEL);
	if (!h)
		return -ENOMEM;

	h->maxelem = maxelem;
#ifdef IP_SET_HASH_WITH_NETMASK
	h->netmask = netmask;
#endif
	get_random_bytes(&h->initval, sizeof(h->initval));
	h->timeout = IPSET_NO_TIMEOUT;

	hbits = htable_bits(hashsize);
	hsize = htable_size(hbits);
	if (hsize == 0) {
		kfree(h);
		return -ENOMEM;
	}
	h->table = ip_set_alloc(hsize);
	if (!h->table) {
		kfree(h);
		return -ENOMEM;
	}
	h->table->htable_bits = hbits;

	set->data = h;
	if (set->family == NFPROTO_IPV4)
		set->variant = &TOKEN(HTYPE, 4_variant);
	else
		set->variant = &TOKEN(HTYPE, 6_variant);

	if (tb[IPSET_ATTR_CADT_FLAGS])
		cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
	if (cadt_flags & IPSET_FLAG_WITH_COUNTERS) {
		set->extensions |= IPSET_EXT_COUNTER;
		if (tb[IPSET_ATTR_TIMEOUT]) {
			h->timeout =
				ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
			set->extensions |= IPSET_EXT_TIMEOUT;
			if (set->family == NFPROTO_IPV4) {
				h->dsize =
					sizeof(struct TOKEN(HTYPE, 4ct_elem));
				h->offset[IPSET_OFFSET_TIMEOUT] =
					offsetof(struct TOKEN(HTYPE, 4ct_elem),
						 timeout);
				h->offset[IPSET_OFFSET_COUNTER] =
					offsetof(struct TOKEN(HTYPE, 4ct_elem),
						 counter);
				TOKEN(HTYPE, 4_gc_init)(set,
					TOKEN(HTYPE, 4_gc));
			} else {
				h->dsize =
					sizeof(struct TOKEN(HTYPE, 6ct_elem));
				h->offset[IPSET_OFFSET_TIMEOUT] =
					offsetof(struct TOKEN(HTYPE, 6ct_elem),
						 timeout);
				h->offset[IPSET_OFFSET_COUNTER] =
					offsetof(struct TOKEN(HTYPE, 6ct_elem),
						 counter);
				TOKEN(HTYPE, 6_gc_init)(set,
					TOKEN(HTYPE, 6_gc));
			}
		} else {
			if (set->family == NFPROTO_IPV4) {
				h->dsize =
					sizeof(struct TOKEN(HTYPE, 4c_elem));
				h->offset[IPSET_OFFSET_COUNTER] =
					offsetof(struct TOKEN(HTYPE, 4c_elem),
						 counter);
			} else {
				h->dsize =
					sizeof(struct TOKEN(HTYPE, 6c_elem));
				h->offset[IPSET_OFFSET_COUNTER] =
					offsetof(struct TOKEN(HTYPE, 6c_elem),
						 counter);
			}
		}
	} else if (tb[IPSET_ATTR_TIMEOUT]) {
		h->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
		set->extensions |= IPSET_EXT_TIMEOUT;
		if (set->family == NFPROTO_IPV4) {
			h->dsize = sizeof(struct TOKEN(HTYPE, 4t_elem));
			h->offset[IPSET_OFFSET_TIMEOUT] =
				offsetof(struct TOKEN(HTYPE, 4t_elem),
					 timeout);
			TOKEN(HTYPE, 4_gc_init)(set, TOKEN(HTYPE, 4_gc));
		} else {
			h->dsize = sizeof(struct TOKEN(HTYPE, 6t_elem));
			h->offset[IPSET_OFFSET_TIMEOUT] =
				offsetof(struct TOKEN(HTYPE, 6t_elem),
					 timeout);
			TOKEN(HTYPE, 6_gc_init)(set, TOKEN(HTYPE, 6_gc));
		}
	} else {
		if (set->family == NFPROTO_IPV4)
			h->dsize = sizeof(struct TOKEN(HTYPE, 4_elem));
		else
			h->dsize = sizeof(struct TOKEN(HTYPE, 6_elem));
	}

	pr_debug("create %s hashsize %u (%u) maxelem %u: %p(%p)\n",
		 set->name, jhash_size(h->table->htable_bits),
		 h->table->htable_bits, h->maxelem, set->data, h->table);

	return 0;
}
#endif /* IP_SET_EMIT_CREATE */