dmatest.c revision a9e554957de406d6adc581731f571b8a1503f6b0
1/*
2 * DMA Engine test module
3 *
4 * Copyright (C) 2007 Atmel Corporation
5 * Copyright (C) 2013 Intel Corporation
6 *
7 * This program is free software; you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License version 2 as
9 * published by the Free Software Foundation.
10 */
11#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
12
#include <linux/delay.h>
#include <linux/dma-mapping.h>
#include <linux/dmaengine.h>
#include <linux/freezer.h>
#include <linux/init.h>
#include <linux/kthread.h>
#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/random.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/wait.h>
24
25static unsigned int test_buf_size = 16384;
26module_param(test_buf_size, uint, S_IRUGO | S_IWUSR);
27MODULE_PARM_DESC(test_buf_size, "Size of the memcpy test buffer");
28
29static char test_channel[20];
30module_param_string(channel, test_channel, sizeof(test_channel),
31		S_IRUGO | S_IWUSR);
32MODULE_PARM_DESC(channel, "Bus ID of the channel to test (default: any)");
33
34static char test_device[20];
35module_param_string(device, test_device, sizeof(test_device),
36		S_IRUGO | S_IWUSR);
37MODULE_PARM_DESC(device, "Bus ID of the DMA Engine to test (default: any)");
38
39static unsigned int threads_per_chan = 1;
40module_param(threads_per_chan, uint, S_IRUGO | S_IWUSR);
41MODULE_PARM_DESC(threads_per_chan,
42		"Number of threads to start per channel (default: 1)");
43
44static unsigned int max_channels;
45module_param(max_channels, uint, S_IRUGO | S_IWUSR);
46MODULE_PARM_DESC(max_channels,
47		"Maximum number of channels to use (default: all)");
48
49static unsigned int iterations;
50module_param(iterations, uint, S_IRUGO | S_IWUSR);
51MODULE_PARM_DESC(iterations,
52		"Iterations before stopping test (default: infinite)");
53
54static unsigned int xor_sources = 3;
55module_param(xor_sources, uint, S_IRUGO | S_IWUSR);
56MODULE_PARM_DESC(xor_sources,
57		"Number of xor source buffers (default: 3)");
58
59static unsigned int pq_sources = 3;
60module_param(pq_sources, uint, S_IRUGO | S_IWUSR);
61MODULE_PARM_DESC(pq_sources,
62		"Number of p+q source buffers (default: 3)");
63
64static int timeout = 3000;
65module_param(timeout, uint, S_IRUGO | S_IWUSR);
66MODULE_PARM_DESC(timeout, "Transfer Timeout in msec (default: 3000), "
67		 "Pass -1 for infinite timeout");
68
/**
 * struct dmatest_params - snapshot of the module parameters for one test run
 * @buf_size:		size in bytes of each source/destination test buffer
 * @channel:		bus ID of the channel to test; empty string matches any
 * @device:		bus ID of the DMA Engine to test; empty string matches any
 * @threads_per_chan:	number of threads started per channel and operation type
 * @max_channels:	maximum number of channels to use; 0 means no limit
 * @iterations:		iterations per thread before stopping; 0 means no limit
 * @xor_sources:	number of xor source buffers
 * @pq_sources:		number of p+q source buffers
 * @timeout:		transfer timeout in msec, -1 for infinite timeout
 */
struct dmatest_params {
	unsigned int	buf_size;
	char		channel[20], device[20];
	unsigned int	threads_per_chan, max_channels, iterations;
	unsigned int	xor_sources, pq_sources;
	int		timeout;
};
92
93/**
94 * struct dmatest_info - test information.
95 * @params:		test parameters
96 * @lock:		access protection to the fields of this structure
97 */
98static struct dmatest_info {
99	/* Test parameters */
100	struct dmatest_params	params;
101
102	/* Internal state */
103	struct list_head	channels;
104	unsigned int		nr_channels;
105	struct mutex		lock;
106	bool			did_init;
107} test_info = {
108	.channels = LIST_HEAD_INIT(test_info.channels),
109	.lock = __MUTEX_INITIALIZER(test_info.lock),
110};
111
112static int dmatest_run_set(const char *val, const struct kernel_param *kp);
113static int dmatest_run_get(char *val, const struct kernel_param *kp);
114static struct kernel_param_ops run_ops = {
115	.set = dmatest_run_set,
116	.get = dmatest_run_get,
117};
118static bool dmatest_run;
119module_param_cb(run, &run_ops, &dmatest_run, S_IRUGO | S_IWUSR);
120MODULE_PARM_DESC(run, "Run the test (default: false)");
121
/* Upper bound on mismatching bytes printed by a single verification pass */
enum { MAX_ERROR_COUNT = 32 };

/*
 * Buffer initialization patterns. Every source byte has bit 7 set and
 * every destination byte has bit 7 cleared.
 *
 * Bit 6 marks bytes the DMA engine is supposed to copy; bit 5 marks
 * destination bytes it is supposed to overwrite.
 *
 * The low five bits hold the bitwise inverse of a counter that grows by
 * one with each byte address.
 */
enum {
	PATTERN_SRC		= 0x80,
	PATTERN_DST		= 0x00,
	PATTERN_COPY		= 0x40,
	PATTERN_OVERWRITE	= 0x20,
	PATTERN_COUNT_MASK	= 0x1f,
};
141
142struct dmatest_thread {
143	struct list_head	node;
144	struct dmatest_info	*info;
145	struct task_struct	*task;
146	struct dma_chan		*chan;
147	u8			**srcs;
148	u8			**dsts;
149	enum dma_transaction_type type;
150	bool			done;
151};
152
153struct dmatest_chan {
154	struct list_head	node;
155	struct dma_chan		*chan;
156	struct list_head	threads;
157};
158
159static bool dmatest_match_channel(struct dmatest_params *params,
160		struct dma_chan *chan)
161{
162	if (params->channel[0] == '\0')
163		return true;
164	return strcmp(dma_chan_name(chan), params->channel) == 0;
165}
166
167static bool dmatest_match_device(struct dmatest_params *params,
168		struct dma_device *device)
169{
170	if (params->device[0] == '\0')
171		return true;
172	return strcmp(dev_name(device->dev), params->device) == 0;
173}
174
/* Return an unsigned long whose bytes come from get_random_bytes(). */
static unsigned long dmatest_random(void)
{
	unsigned long rnd;

	get_random_bytes(&rnd, sizeof(rnd));

	return rnd;
}
182
183static void dmatest_init_srcs(u8 **bufs, unsigned int start, unsigned int len,
184		unsigned int buf_size)
185{
186	unsigned int i;
187	u8 *buf;
188
189	for (; (buf = *bufs); bufs++) {
190		for (i = 0; i < start; i++)
191			buf[i] = PATTERN_SRC | (~i & PATTERN_COUNT_MASK);
192		for ( ; i < start + len; i++)
193			buf[i] = PATTERN_SRC | PATTERN_COPY
194				| (~i & PATTERN_COUNT_MASK);
195		for ( ; i < buf_size; i++)
196			buf[i] = PATTERN_SRC | (~i & PATTERN_COUNT_MASK);
197		buf++;
198	}
199}
200
201static void dmatest_init_dsts(u8 **bufs, unsigned int start, unsigned int len,
202		unsigned int buf_size)
203{
204	unsigned int i;
205	u8 *buf;
206
207	for (; (buf = *bufs); bufs++) {
208		for (i = 0; i < start; i++)
209			buf[i] = PATTERN_DST | (~i & PATTERN_COUNT_MASK);
210		for ( ; i < start + len; i++)
211			buf[i] = PATTERN_DST | PATTERN_OVERWRITE
212				| (~i & PATTERN_COUNT_MASK);
213		for ( ; i < buf_size; i++)
214			buf[i] = PATTERN_DST | (~i & PATTERN_COUNT_MASK);
215	}
216}
217
218static void dmatest_mismatch(u8 actual, u8 pattern, unsigned int index,
219		unsigned int counter, bool is_srcbuf)
220{
221	u8		diff = actual ^ pattern;
222	u8		expected = pattern | (~counter & PATTERN_COUNT_MASK);
223	const char	*thread_name = current->comm;
224
225	if (is_srcbuf)
226		pr_warn("%s: srcbuf[0x%x] overwritten! Expected %02x, got %02x\n",
227			thread_name, index, expected, actual);
228	else if ((pattern & PATTERN_COPY)
229			&& (diff & (PATTERN_COPY | PATTERN_OVERWRITE)))
230		pr_warn("%s: dstbuf[0x%x] not copied! Expected %02x, got %02x\n",
231			thread_name, index, expected, actual);
232	else if (diff & PATTERN_SRC)
233		pr_warn("%s: dstbuf[0x%x] was copied! Expected %02x, got %02x\n",
234			thread_name, index, expected, actual);
235	else
236		pr_warn("%s: dstbuf[0x%x] mismatch! Expected %02x, got %02x\n",
237			thread_name, index, expected, actual);
238}
239
240static unsigned int dmatest_verify(u8 **bufs, unsigned int start,
241		unsigned int end, unsigned int counter, u8 pattern,
242		bool is_srcbuf)
243{
244	unsigned int i;
245	unsigned int error_count = 0;
246	u8 actual;
247	u8 expected;
248	u8 *buf;
249	unsigned int counter_orig = counter;
250
251	for (; (buf = *bufs); bufs++) {
252		counter = counter_orig;
253		for (i = start; i < end; i++) {
254			actual = buf[i];
255			expected = pattern | (~counter & PATTERN_COUNT_MASK);
256			if (actual != expected) {
257				if (error_count < MAX_ERROR_COUNT)
258					dmatest_mismatch(actual, pattern, i,
259							 counter, is_srcbuf);
260				error_count++;
261			}
262			counter++;
263		}
264	}
265
266	if (error_count > MAX_ERROR_COUNT)
267		pr_warn("%s: %u errors suppressed\n",
268			current->comm, error_count - MAX_ERROR_COUNT);
269
270	return error_count;
271}
272
273/* poor man's completion - we want to use wait_event_freezable() on it */
274struct dmatest_done {
275	bool			done;
276	wait_queue_head_t	*wait;
277};
278
279static void dmatest_callback(void *arg)
280{
281	struct dmatest_done *done = arg;
282
283	done->done = true;
284	wake_up_all(done->wait);
285}
286
287static inline void unmap_src(struct device *dev, dma_addr_t *addr, size_t len,
288			     unsigned int count)
289{
290	while (count--)
291		dma_unmap_single(dev, addr[count], len, DMA_TO_DEVICE);
292}
293
294static inline void unmap_dst(struct device *dev, dma_addr_t *addr, size_t len,
295			     unsigned int count)
296{
297	while (count--)
298		dma_unmap_single(dev, addr[count], len, DMA_BIDIRECTIONAL);
299}
300
/*
 * Return the smaller of @x and @y, decremented by one when it is even.
 * Note: a minimum of 0 wraps around to UINT_MAX.
 */
static unsigned int min_odd(unsigned int x, unsigned int y)
{
	unsigned int smaller = x < y ? x : y;

	if (smaller & 1)
		return smaller;
	return smaller - 1;
}
307
308static void result(const char *err, unsigned int n, unsigned int src_off,
309		   unsigned int dst_off, unsigned int len, unsigned long data)
310{
311	pr_info("%s: result #%u: '%s' with src_off=0x%x ""dst_off=0x%x len=0x%x (%lu)",
312		current->comm, n, err, src_off, dst_off, len, data);
313}
314
315static void dbg_result(const char *err, unsigned int n, unsigned int src_off,
316		       unsigned int dst_off, unsigned int len,
317		       unsigned long data)
318{
319	pr_debug("%s: result #%u: '%s' with src_off=0x%x ""dst_off=0x%x len=0x%x (%lu)",
320		 current->comm, n, err, src_off, dst_off, len, data);
321}
322
323/*
324 * This function repeatedly tests DMA transfers of various lengths and
325 * offsets for a given operation type until it is told to exit by
326 * kthread_stop(). There may be multiple threads running this function
327 * in parallel for a single channel, and there may be multiple channels
328 * being tested in parallel.
329 *
330 * Before each test, the source and destination buffer is initialized
331 * with a known pattern. This pattern is different depending on
332 * whether it's in an area which is supposed to be copied or
333 * overwritten, and different in the source and destination buffers.
334 * So if the DMA engine doesn't copy exactly what we tell it to copy,
335 * we'll notice.
336 */
337static int dmatest_func(void *data)
338{
339	DECLARE_WAIT_QUEUE_HEAD_ONSTACK(done_wait);
340	struct dmatest_thread	*thread = data;
341	struct dmatest_done	done = { .wait = &done_wait };
342	struct dmatest_info	*info;
343	struct dmatest_params	*params;
344	struct dma_chan		*chan;
345	struct dma_device	*dev;
346	unsigned int		src_off, dst_off, len;
347	unsigned int		error_count;
348	unsigned int		failed_tests = 0;
349	unsigned int		total_tests = 0;
350	dma_cookie_t		cookie;
351	enum dma_status		status;
352	enum dma_ctrl_flags 	flags;
353	u8			*pq_coefs = NULL;
354	int			ret;
355	int			src_cnt;
356	int			dst_cnt;
357	int			i;
358
359	set_freezable();
360
361	ret = -ENOMEM;
362
363	smp_rmb();
364	info = thread->info;
365	params = &info->params;
366	chan = thread->chan;
367	dev = chan->device;
368	if (thread->type == DMA_MEMCPY)
369		src_cnt = dst_cnt = 1;
370	else if (thread->type == DMA_XOR) {
371		/* force odd to ensure dst = src */
372		src_cnt = min_odd(params->xor_sources | 1, dev->max_xor);
373		dst_cnt = 1;
374	} else if (thread->type == DMA_PQ) {
375		/* force odd to ensure dst = src */
376		src_cnt = min_odd(params->pq_sources | 1, dma_maxpq(dev, 0));
377		dst_cnt = 2;
378
379		pq_coefs = kmalloc(params->pq_sources+1, GFP_KERNEL);
380		if (!pq_coefs)
381			goto err_thread_type;
382
383		for (i = 0; i < src_cnt; i++)
384			pq_coefs[i] = 1;
385	} else
386		goto err_thread_type;
387
388	thread->srcs = kcalloc(src_cnt+1, sizeof(u8 *), GFP_KERNEL);
389	if (!thread->srcs)
390		goto err_srcs;
391	for (i = 0; i < src_cnt; i++) {
392		thread->srcs[i] = kmalloc(params->buf_size, GFP_KERNEL);
393		if (!thread->srcs[i])
394			goto err_srcbuf;
395	}
396	thread->srcs[i] = NULL;
397
398	thread->dsts = kcalloc(dst_cnt+1, sizeof(u8 *), GFP_KERNEL);
399	if (!thread->dsts)
400		goto err_dsts;
401	for (i = 0; i < dst_cnt; i++) {
402		thread->dsts[i] = kmalloc(params->buf_size, GFP_KERNEL);
403		if (!thread->dsts[i])
404			goto err_dstbuf;
405	}
406	thread->dsts[i] = NULL;
407
408	set_user_nice(current, 10);
409
410	/*
411	 * src and dst buffers are freed by ourselves below
412	 */
413	flags = DMA_CTRL_ACK | DMA_PREP_INTERRUPT;
414
415	while (!kthread_should_stop()
416	       && !(params->iterations && total_tests >= params->iterations)) {
417		struct dma_async_tx_descriptor *tx = NULL;
418		dma_addr_t dma_srcs[src_cnt];
419		dma_addr_t dma_dsts[dst_cnt];
420		u8 align = 0;
421
422		total_tests++;
423
424		/* honor alignment restrictions */
425		if (thread->type == DMA_MEMCPY)
426			align = dev->copy_align;
427		else if (thread->type == DMA_XOR)
428			align = dev->xor_align;
429		else if (thread->type == DMA_PQ)
430			align = dev->pq_align;
431
432		if (1 << align > params->buf_size) {
433			pr_err("%u-byte buffer too small for %d-byte alignment\n",
434			       params->buf_size, 1 << align);
435			break;
436		}
437
438		len = dmatest_random() % params->buf_size + 1;
439		len = (len >> align) << align;
440		if (!len)
441			len = 1 << align;
442		src_off = dmatest_random() % (params->buf_size - len + 1);
443		dst_off = dmatest_random() % (params->buf_size - len + 1);
444
445		src_off = (src_off >> align) << align;
446		dst_off = (dst_off >> align) << align;
447
448		dmatest_init_srcs(thread->srcs, src_off, len, params->buf_size);
449		dmatest_init_dsts(thread->dsts, dst_off, len, params->buf_size);
450
451		for (i = 0; i < src_cnt; i++) {
452			u8 *buf = thread->srcs[i] + src_off;
453
454			dma_srcs[i] = dma_map_single(dev->dev, buf, len,
455						     DMA_TO_DEVICE);
456			ret = dma_mapping_error(dev->dev, dma_srcs[i]);
457			if (ret) {
458				unmap_src(dev->dev, dma_srcs, len, i);
459				result("src mapping error", total_tests,
460				       src_off, dst_off, len, ret);
461				failed_tests++;
462				continue;
463			}
464		}
465		/* map with DMA_BIDIRECTIONAL to force writeback/invalidate */
466		for (i = 0; i < dst_cnt; i++) {
467			dma_dsts[i] = dma_map_single(dev->dev, thread->dsts[i],
468						     params->buf_size,
469						     DMA_BIDIRECTIONAL);
470			ret = dma_mapping_error(dev->dev, dma_dsts[i]);
471			if (ret) {
472				unmap_src(dev->dev, dma_srcs, len, src_cnt);
473				unmap_dst(dev->dev, dma_dsts, params->buf_size,
474					  i);
475				result("dst mapping error", total_tests,
476				       src_off, dst_off, len, ret);
477				failed_tests++;
478				continue;
479			}
480		}
481
482		if (thread->type == DMA_MEMCPY)
483			tx = dev->device_prep_dma_memcpy(chan,
484							 dma_dsts[0] + dst_off,
485							 dma_srcs[0], len,
486							 flags);
487		else if (thread->type == DMA_XOR)
488			tx = dev->device_prep_dma_xor(chan,
489						      dma_dsts[0] + dst_off,
490						      dma_srcs, src_cnt,
491						      len, flags);
492		else if (thread->type == DMA_PQ) {
493			dma_addr_t dma_pq[dst_cnt];
494
495			for (i = 0; i < dst_cnt; i++)
496				dma_pq[i] = dma_dsts[i] + dst_off;
497			tx = dev->device_prep_dma_pq(chan, dma_pq, dma_srcs,
498						     src_cnt, pq_coefs,
499						     len, flags);
500		}
501
502		if (!tx) {
503			unmap_src(dev->dev, dma_srcs, len, src_cnt);
504			unmap_dst(dev->dev, dma_dsts, params->buf_size,
505				  dst_cnt);
506			result("prep error", total_tests, src_off,
507			       dst_off, len, ret);
508			msleep(100);
509			failed_tests++;
510			continue;
511		}
512
513		done.done = false;
514		tx->callback = dmatest_callback;
515		tx->callback_param = &done;
516		cookie = tx->tx_submit(tx);
517
518		if (dma_submit_error(cookie)) {
519			result("submit error", total_tests, src_off,
520			       dst_off, len, ret);
521			msleep(100);
522			failed_tests++;
523			continue;
524		}
525		dma_async_issue_pending(chan);
526
527		wait_event_freezable_timeout(done_wait, done.done,
528					     msecs_to_jiffies(params->timeout));
529
530		status = dma_async_is_tx_complete(chan, cookie, NULL, NULL);
531
532		if (!done.done) {
533			/*
534			 * We're leaving the timed out dma operation with
535			 * dangling pointer to done_wait.  To make this
536			 * correct, we'll need to allocate wait_done for
537			 * each test iteration and perform "who's gonna
538			 * free it this time?" dancing.  For now, just
539			 * leave it dangling.
540			 */
541			result("test timed out", total_tests, src_off, dst_off,
542			       len, 0);
543			failed_tests++;
544			continue;
545		} else if (status != DMA_SUCCESS) {
546			result(status == DMA_ERROR ?
547			       "completion error status" :
548			       "completion busy status", total_tests, src_off,
549			       dst_off, len, ret);
550			failed_tests++;
551			continue;
552		}
553
554		/* Unmap by myself */
555		unmap_src(dev->dev, dma_srcs, len, src_cnt);
556		unmap_dst(dev->dev, dma_dsts, params->buf_size, dst_cnt);
557
558		error_count = 0;
559
560		pr_debug("%s: verifying source buffer...\n", current->comm);
561		error_count += dmatest_verify(thread->srcs, 0, src_off,
562				0, PATTERN_SRC, true);
563		error_count += dmatest_verify(thread->srcs, src_off,
564				src_off + len, src_off,
565				PATTERN_SRC | PATTERN_COPY, true);
566		error_count += dmatest_verify(thread->srcs, src_off + len,
567				params->buf_size, src_off + len,
568				PATTERN_SRC, true);
569
570		pr_debug("%s: verifying dest buffer...\n", current->comm);
571		error_count += dmatest_verify(thread->dsts, 0, dst_off,
572				0, PATTERN_DST, false);
573		error_count += dmatest_verify(thread->dsts, dst_off,
574				dst_off + len, src_off,
575				PATTERN_SRC | PATTERN_COPY, false);
576		error_count += dmatest_verify(thread->dsts, dst_off + len,
577				params->buf_size, dst_off + len,
578				PATTERN_DST, false);
579
580		if (error_count) {
581			result("data error", total_tests, src_off, dst_off,
582			       len, error_count);
583			failed_tests++;
584		} else {
585			dbg_result("test passed", total_tests, src_off, dst_off,
586				   len, 0);
587		}
588	}
589
590	ret = 0;
591	for (i = 0; thread->dsts[i]; i++)
592		kfree(thread->dsts[i]);
593err_dstbuf:
594	kfree(thread->dsts);
595err_dsts:
596	for (i = 0; thread->srcs[i]; i++)
597		kfree(thread->srcs[i]);
598err_srcbuf:
599	kfree(thread->srcs);
600err_srcs:
601	kfree(pq_coefs);
602err_thread_type:
603	pr_info("%s: terminating after %u tests, %u failures (status %d)\n",
604		current->comm, total_tests, failed_tests, ret);
605
606	/* terminate all transfers on specified channels */
607	if (ret)
608		dmaengine_terminate_all(chan);
609
610	thread->done = true;
611
612	if (params->iterations > 0)
613		while (!kthread_should_stop()) {
614			DECLARE_WAIT_QUEUE_HEAD_ONSTACK(wait_dmatest_exit);
615			interruptible_sleep_on(&wait_dmatest_exit);
616		}
617
618	return ret;
619}
620
621static void dmatest_cleanup_channel(struct dmatest_chan *dtc)
622{
623	struct dmatest_thread	*thread;
624	struct dmatest_thread	*_thread;
625	int			ret;
626
627	list_for_each_entry_safe(thread, _thread, &dtc->threads, node) {
628		ret = kthread_stop(thread->task);
629		pr_debug("thread %s exited with status %d\n",
630			 thread->task->comm, ret);
631		list_del(&thread->node);
632		kfree(thread);
633	}
634
635	/* terminate all transfers on specified channels */
636	dmaengine_terminate_all(dtc->chan);
637
638	kfree(dtc);
639}
640
641static int dmatest_add_threads(struct dmatest_info *info,
642		struct dmatest_chan *dtc, enum dma_transaction_type type)
643{
644	struct dmatest_params *params = &info->params;
645	struct dmatest_thread *thread;
646	struct dma_chan *chan = dtc->chan;
647	char *op;
648	unsigned int i;
649
650	if (type == DMA_MEMCPY)
651		op = "copy";
652	else if (type == DMA_XOR)
653		op = "xor";
654	else if (type == DMA_PQ)
655		op = "pq";
656	else
657		return -EINVAL;
658
659	for (i = 0; i < params->threads_per_chan; i++) {
660		thread = kzalloc(sizeof(struct dmatest_thread), GFP_KERNEL);
661		if (!thread) {
662			pr_warn("No memory for %s-%s%u\n",
663				dma_chan_name(chan), op, i);
664			break;
665		}
666		thread->info = info;
667		thread->chan = dtc->chan;
668		thread->type = type;
669		smp_wmb();
670		thread->task = kthread_run(dmatest_func, thread, "%s-%s%u",
671				dma_chan_name(chan), op, i);
672		if (IS_ERR(thread->task)) {
673			pr_warn("Failed to run thread %s-%s%u\n",
674				dma_chan_name(chan), op, i);
675			kfree(thread);
676			break;
677		}
678
679		/* srcbuf and dstbuf are allocated by the thread itself */
680
681		list_add_tail(&thread->node, &dtc->threads);
682	}
683
684	return i;
685}
686
687static int dmatest_add_channel(struct dmatest_info *info,
688		struct dma_chan *chan)
689{
690	struct dmatest_chan	*dtc;
691	struct dma_device	*dma_dev = chan->device;
692	unsigned int		thread_count = 0;
693	int cnt;
694
695	dtc = kmalloc(sizeof(struct dmatest_chan), GFP_KERNEL);
696	if (!dtc) {
697		pr_warn("No memory for %s\n", dma_chan_name(chan));
698		return -ENOMEM;
699	}
700
701	dtc->chan = chan;
702	INIT_LIST_HEAD(&dtc->threads);
703
704	if (dma_has_cap(DMA_MEMCPY, dma_dev->cap_mask)) {
705		cnt = dmatest_add_threads(info, dtc, DMA_MEMCPY);
706		thread_count += cnt > 0 ? cnt : 0;
707	}
708	if (dma_has_cap(DMA_XOR, dma_dev->cap_mask)) {
709		cnt = dmatest_add_threads(info, dtc, DMA_XOR);
710		thread_count += cnt > 0 ? cnt : 0;
711	}
712	if (dma_has_cap(DMA_PQ, dma_dev->cap_mask)) {
713		cnt = dmatest_add_threads(info, dtc, DMA_PQ);
714		thread_count += cnt > 0 ? cnt : 0;
715	}
716
717	pr_info("Started %u threads using %s\n",
718		thread_count, dma_chan_name(chan));
719
720	list_add_tail(&dtc->node, &info->channels);
721	info->nr_channels++;
722
723	return 0;
724}
725
726static bool filter(struct dma_chan *chan, void *param)
727{
728	struct dmatest_params *params = param;
729
730	if (!dmatest_match_channel(params, chan) ||
731	    !dmatest_match_device(params, chan->device))
732		return false;
733	else
734		return true;
735}
736
737static void request_channels(struct dmatest_info *info,
738			     enum dma_transaction_type type)
739{
740	dma_cap_mask_t mask;
741
742	dma_cap_zero(mask);
743	dma_cap_set(type, mask);
744	for (;;) {
745		struct dmatest_params *params = &info->params;
746		struct dma_chan *chan;
747
748		chan = dma_request_channel(mask, filter, params);
749		if (chan) {
750			if (dmatest_add_channel(info, chan)) {
751				dma_release_channel(chan);
752				break; /* add_channel failed, punt */
753			}
754		} else
755			break; /* no more channels available */
756		if (params->max_channels &&
757		    info->nr_channels >= params->max_channels)
758			break; /* we have all we need */
759	}
760}
761
762static void run_threaded_test(struct dmatest_info *info)
763{
764	struct dmatest_params *params = &info->params;
765
766	/* Copy test parameters */
767	params->buf_size = test_buf_size;
768	strlcpy(params->channel, strim(test_channel), sizeof(params->channel));
769	strlcpy(params->device, strim(test_device), sizeof(params->device));
770	params->threads_per_chan = threads_per_chan;
771	params->max_channels = max_channels;
772	params->iterations = iterations;
773	params->xor_sources = xor_sources;
774	params->pq_sources = pq_sources;
775	params->timeout = timeout;
776
777	request_channels(info, DMA_MEMCPY);
778	request_channels(info, DMA_XOR);
779	request_channels(info, DMA_PQ);
780}
781
782static void stop_threaded_test(struct dmatest_info *info)
783{
784	struct dmatest_chan *dtc, *_dtc;
785	struct dma_chan *chan;
786
787	list_for_each_entry_safe(dtc, _dtc, &info->channels, node) {
788		list_del(&dtc->node);
789		chan = dtc->chan;
790		dmatest_cleanup_channel(dtc);
791		pr_debug("dropped channel %s\n", dma_chan_name(chan));
792		dma_release_channel(chan);
793	}
794
795	info->nr_channels = 0;
796}
797
798static void restart_threaded_test(struct dmatest_info *info, bool run)
799{
800	/* we might be called early to set run=, defer running until all
801	 * parameters have been evaluated
802	 */
803	if (!info->did_init)
804		return;
805
806	/* Stop any running test first */
807	stop_threaded_test(info);
808
809	/* Run test with new parameters */
810	run_threaded_test(info);
811}
812
813static bool is_threaded_test_run(struct dmatest_info *info)
814{
815	struct dmatest_chan *dtc;
816
817	list_for_each_entry(dtc, &info->channels, node) {
818		struct dmatest_thread *thread;
819
820		list_for_each_entry(thread, &dtc->threads, node) {
821			if (!thread->done)
822				return true;
823		}
824	}
825
826	return false;
827}
828
829static int dmatest_run_get(char *val, const struct kernel_param *kp)
830{
831	struct dmatest_info *info = &test_info;
832
833	mutex_lock(&info->lock);
834	if (is_threaded_test_run(info)) {
835		dmatest_run = true;
836	} else {
837		stop_threaded_test(info);
838		dmatest_run = false;
839	}
840	mutex_unlock(&info->lock);
841
842	return param_get_bool(val, kp);
843}
844
845static int dmatest_run_set(const char *val, const struct kernel_param *kp)
846{
847	struct dmatest_info *info = &test_info;
848	int ret;
849
850	mutex_lock(&info->lock);
851	ret = param_set_bool(val, kp);
852	if (ret) {
853		mutex_unlock(&info->lock);
854		return ret;
855	}
856
857	if (is_threaded_test_run(info))
858		ret = -EBUSY;
859	else if (dmatest_run)
860		restart_threaded_test(info, dmatest_run);
861
862	mutex_unlock(&info->lock);
863
864	return ret;
865}
866
867static int __init dmatest_init(void)
868{
869	struct dmatest_info *info = &test_info;
870
871	if (dmatest_run) {
872		mutex_lock(&info->lock);
873		run_threaded_test(info);
874		mutex_unlock(&info->lock);
875	}
876
877	/* module parameters are stable, inittime tests are started,
878	 * let userspace take over 'run' control
879	 */
880	info->did_init = true;
881
882	return 0;
883}
884/* when compiled-in wait for drivers to load first */
885late_initcall(dmatest_init);
886
887static void __exit dmatest_exit(void)
888{
889	struct dmatest_info *info = &test_info;
890
891	mutex_lock(&info->lock);
892	stop_threaded_test(info);
893	mutex_unlock(&info->lock);
894}
895module_exit(dmatest_exit);
896
897MODULE_AUTHOR("Haavard Skinnemoen (Atmel)");
898MODULE_LICENSE("GPL v2");
899