/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <joerg.roedel@amd.com>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
 */

#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/iommu.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu_types.h"
#include "amd_iommu_proto.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <joerg.roedel@amd.com>");

#define MAX_DEVICES		0x10000
#define PRI_QUEUE_SIZE		512

struct pri_queue {
	atomic_t inflight;
	bool finish;
	int status;
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	atomic_t count;				/* Reference count */
	struct task_struct *task;		/* Task bound to this PASID */
	struct mm_struct *mm;			/* mm_struct for the faults */
	struct mmu_notifier mn;			/* mmu_notifier handle */
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	int pasid;				/* PASID index */
	spinlock_t lock;			/* Protect pri_queues */
	wait_queue_head_t wq;			/* To wait for count == 0 */
};

struct device_state {
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
	amd_iommu_invalidate_ctx inv_ctx_cb;
	spinlock_t lock;
	wait_queue_head_t wq;
};

struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;
	struct mm_struct *mm;
	u64 address;
	u16 devid;
	u16 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

static struct device_state **state_table;
static spinlock_t state_lock;

/* List and lock for all pasid_states */
static LIST_HEAD(pasid_state_list);
static DEFINE_SPINLOCK(ps_lock);

static struct workqueue_struct *iommu_wq;

/*
 * Empty page table - Used between
 * mmu_notifier_invalidate_range_start and
 * mmu_notifier_invalidate_range_end
 */
static u64 *empty_page_table;

static void free_pasid_states(struct device_state *dev_state);
static void unbind_pasid(struct device_state *dev_state, int pasid);
static int task_exit(struct notifier_block *nb, unsigned long e, void *data);

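/* Compute the 16-bit PCI request ID (bus number and devfn) for @pdev */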
static u16 device_id(struct pci_dev *pdev)
{
	u16 devid;

	devid = pdev->bus->number;
	devid = (devid << 8) | pdev->devfn;

	return devid;
}

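/* Look up the device_state for a request ID and take a reference on it */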
static struct device_state *get_device_state(u16 devid)
{
	struct device_state *dev_state;
	unsigned long flags;

	spin_lock_irqsave(&state_lock, flags);
	dev_state = state_table[devid];
	if (dev_state != NULL)
		atomic_inc(&dev_state->count);
	spin_unlock_irqrestore(&state_lock, flags);

	return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
	iommu_detach_device(dev_state->domain, &dev_state->pdev->dev);

	/* Everything is down now, free the IOMMUv2 domain */
	iommu_domain_free(dev_state->domain);

	/* Finally get rid of the device-state */
	kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
	if (atomic_dec_and_test(&dev_state->count))
		wake_up(&dev_state->wq);
}

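/*
 * Drop the caller's reference and wait until all other references are
 * gone before freeing the device_state.
 */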
static void put_device_state_wait(struct device_state *dev_state)
{
	DEFINE_WAIT(wait);

	prepare_to_wait(&dev_state->wq, &wait, TASK_UNINTERRUPTIBLE);
	if (!atomic_dec_and_test(&dev_state->count))
		schedule();
	finish_wait(&dev_state->wq, &wait);

	free_device_state(dev_state);
}

static struct notifier_block profile_nb = {
	.notifier_call = task_exit,
};

static void link_pasid_state(struct pasid_state *pasid_state)
{
	spin_lock(&ps_lock);
	list_add_tail(&pasid_state->list, &pasid_state_list);
	spin_unlock(&ps_lock);
}

static void __unlink_pasid_state(struct pasid_state *pasid_state)
{
	list_del(&pasid_state->list);
}

static void unlink_pasid_state(struct pasid_state *pasid_state)
{
	spin_lock(&ps_lock);
	__unlink_pasid_state(pasid_state);
	spin_unlock(&ps_lock);
}

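/*
 * Walk the multi-level pasid_state table (512 entries per level) and
 * return a pointer to the slot for @pasid. Intermediate levels are
 * allocated with GFP_ATOMIC when @alloc is true.
 */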
/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  int pasid, bool alloc)
{
	struct pasid_state **root, **ptr;
	int level, index;

	level = dev_state->pasid_levels;
	root  = dev_state->states;

	while (true) {

		index = (pasid >> (9 * level)) & 0x1ff;
		ptr   = &root[index];

		if (level == 0)
			break;

		if (*ptr == NULL) {
			if (!alloc)
				return NULL;

			*ptr = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*ptr == NULL)
				return NULL;
		}

		root   = (struct pasid_state **)*ptr;
		level -= 1;
	}

	return ptr;
}

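/*
 * Install @pasid_state into the device's pasid table. Fails when a table
 * level cannot be allocated or the slot is already occupied.
 */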
static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   int pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	ret = -ENOMEM;
	if (ptr == NULL)
		goto out_unlock;

	ret = -ENOMEM;
	if (*ptr != NULL)
		goto out_unlock;

	*ptr = pasid_state;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void clear_pasid_state(struct device_state *dev_state, int pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

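/* Look up the pasid_state for @pasid and take a reference on it */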
static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   int pasid)
{
	struct pasid_state **ptr, *ret = NULL;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, false);

	if (ptr == NULL)
		goto out_unlock;

	ret = *ptr;
	if (ret)
		atomic_inc(&ret->count);

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
	if (atomic_dec_and_test(&pasid_state->count)) {
		put_device_state(pasid_state->device_state);
		wake_up(&pasid_state->wq);
	}
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
	DEFINE_WAIT(wait);

	prepare_to_wait(&pasid_state->wq, &wait, TASK_UNINTERRUPTIBLE);

	if (atomic_dec_and_test(&pasid_state->count))
		put_device_state(pasid_state->device_state);
	else
		schedule();

	finish_wait(&pasid_state->wq, &wait);
	mmput(pasid_state->mm);
	free_pasid_state(pasid_state);
}

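/*
 * Tear down a PASID binding: clear the GCR3 entry, remove the pasid_state
 * from the device table, drain pending faults from the workqueue and
 * unregister the mmu_notifier.
 */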
static void __unbind_pasid(struct pasid_state *pasid_state)
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);
	clear_pasid_state(pasid_state->device_state, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);

	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

	put_pasid_state(pasid_state); /* Reference taken in bind() function */
}

static void unbind_pasid(struct device_state *dev_state, int pasid)
{
	struct pasid_state *pasid_state;

	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		return;

	unlink_pasid_state(pasid_state);
	__unbind_pasid(pasid_state);
	put_pasid_state_wait(pasid_state); /* Reference taken in this function */
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		free_page((unsigned long)tbl[i]);
	}
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);
		unbind_pasid(dev_state, i);
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else if (dev_state->pasid_levels != 0)
		BUG();

	free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	return container_of(mn, struct pasid_state, mn);
}

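/* Flush a single page for this PASID from the IOMMU TLB */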
static void __mn_flush_page(struct mmu_notifier *mn,
			    unsigned long address)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	amd_iommu_flush_page(dev_state->domain, pasid_state->pasid, address);
}

static int mn_clear_flush_young(struct mmu_notifier *mn,
				struct mm_struct *mm,
				unsigned long address)
{
	__mn_flush_page(mn, address);

	return 0;
}

static void mn_change_pte(struct mmu_notifier *mn,
			  struct mm_struct *mm,
			  unsigned long address,
			  pte_t pte)
{
	__mn_flush_page(mn, address);
}

static void mn_invalidate_page(struct mmu_notifier *mn,
			       struct mm_struct *mm,
			       unsigned long address)
{
	__mn_flush_page(mn, address);
}

static void mn_invalidate_range_start(struct mmu_notifier *mn,
				      struct mm_struct *mm,
				      unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	amd_iommu_domain_set_gcr3(dev_state->domain, pasid_state->pasid,
				  __pa(empty_page_table));
}

static void mn_invalidate_range_end(struct mmu_notifier *mn,
				    struct mm_struct *mm,
				    unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	amd_iommu_domain_set_gcr3(dev_state->domain, pasid_state->pasid,
				  __pa(pasid_state->mm->pgd));
}

static struct mmu_notifier_ops iommu_mn = {
	.clear_flush_young      = mn_clear_flush_young,
	.change_pte             = mn_change_pte,
	.invalidate_page        = mn_invalidate_page,
	.invalidate_range_start = mn_invalidate_range_start,
	.invalidate_range_end   = mn_invalidate_range_end,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	pasid_state->pri[tag].status = status;
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

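/*
 * Send the PPR completion for @tag once the last in-flight fault for it
 * has finished and the device requested a completion message.
 */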
static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
	    pasid_state->pri[tag].finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pasid_state->pri[tag].status, tag);
		pasid_state->pri[tag].finish = false;
		pasid_state->pri[tag].status = PPR_SUCCESS;
	}
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

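/*
 * Workqueue handler: try to resolve a peripheral page fault by faulting
 * the page in with get_user_pages(). If that fails, let the device driver
 * pick the PPR response through its invalid-PPR callback.
 */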
static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
	int npages, write;
	struct page *page;

	write = !!(fault->flags & PPR_FAULT_WRITE);

	npages = get_user_pages(fault->state->task, fault->state->mm,
				fault->address, 1, write, 0, &page, NULL);

	if (npages == 1) {
		put_page(page);
	} else if (fault->dev_state->inv_ppr_cb) {
		int status;

		status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
						      fault->pasid,
						      fault->address,
						      fault->flags);
		switch (status) {
		case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
			set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
			break;
		case AMD_IOMMU_INV_PRI_RSP_INVALID:
			set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
			break;
		case AMD_IOMMU_INV_PRI_RSP_FAIL:
			set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
			break;
		default:
			BUG();
		}
	} else {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
	}

	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

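/*
 * PPR notifier callback. This may run in atomic context (note the
 * GFP_ATOMIC allocation), so the real fault handling is deferred to the
 * iommu_wq workqueue.
 */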
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag;
	int ret;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	ret = NOTIFY_DONE;
	dev_state = get_device_state(iommu_fault->device_id);
	if (dev_state == NULL)
		goto out;

	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
	if (pasid_state == NULL) {
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:
	put_device_state(dev_state);

out:
	return ret;
}

static struct notifier_block ppr_nb = {
	.notifier_call = ppr_notifier,
};

static int task_exit(struct notifier_block *nb, unsigned long e, void *data)
{
	struct pasid_state *pasid_state;
	struct task_struct *task;

	task = data;

	/*
	 * Using this notifier is a hack - but there is no other choice
	 * at the moment. What I really want is a sleeping notifier that
	 * is called when an MM goes down. But such a notifier doesn't
	 * exist yet. The notifier needs to sleep because it has to make
	 * sure that the device does not use the PASID and the address
	 * space anymore before it is destroyed. This includes waiting
	 * for pending PRI requests to pass the workqueue. The
	 * MMU-Notifiers would be a good fit, but they use RCU and so
	 * they are not allowed to sleep. Let's see how we can solve this
	 * in a more intelligent way in the future.
	 */
again:
	spin_lock(&ps_lock);
	list_for_each_entry(pasid_state, &pasid_state_list, list) {
		struct device_state *dev_state;
		int pasid;

		if (pasid_state->task != task)
			continue;

		/* Drop Lock and unbind */
		spin_unlock(&ps_lock);

		dev_state = pasid_state->device_state;
		pasid     = pasid_state->pasid;

		if (pasid_state->device_state->inv_ctx_cb)
			dev_state->inv_ctx_cb(dev_state->pdev, pasid);

		unbind_pasid(dev_state, pasid);

		/* Task may be in the list multiple times */
		goto again;
	}
	spin_unlock(&ps_lock);

	return NOTIFY_OK;
}

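/*
 * Bind the address space of @task to the given PASID of @pdev: set the
 * GCR3 table entry to the task's page table and register an mmu_notifier
 * for TLB maintenance.
 */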
int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	u16 devid;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid     = device_id(pdev);
	dev_state = get_device_state(devid);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid < 0 || pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;

	atomic_set(&pasid_state->count, 1);
	init_waitqueue_head(&pasid_state->wq);
	pasid_state->task         = task;
	pasid_state->mm           = get_task_mm(task);
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
	pasid_state->mn.ops       = &iommu_mn;

	if (pasid_state->mm == NULL)
		goto out_free;

	mmu_notifier_register(&pasid_state->mn, pasid_state->mm);

	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
		goto out_unregister;

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

	link_pasid_state(pasid_state);

	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

out_unregister:
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

out_free:
	free_pasid_state(pasid_state);

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
{
	struct device_state *dev_state;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);
	dev_state = get_device_state(devid);
	if (dev_state == NULL)
		return;

	if (pasid < 0 || pasid >= dev_state->max_pasids)
		goto out;

	unbind_pasid(dev_state, pasid);

out:
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

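/*
 * Set up IOMMUv2 state for @pdev: allocate the device_state and pasid
 * table, create a direct-mapped v2 domain supporting @pasids PASIDs and
 * attach the device to it.
 */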
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	unsigned long flags;
	int ret, tmp;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	devid = device_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
	init_waitqueue_head(&dev_state->wq);
	dev_state->pdev = pdev;

	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	ret = iommu_attach_device(dev_state->domain, &pdev->dev);
	if (ret != 0)
		goto out_free_domain;

	spin_lock_irqsave(&state_lock, flags);

	if (state_table[devid] != NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

	state_table[devid] = dev_state;

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	dev_state = state_table[devid];
	if (dev_state == NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		return;
	}

	state_table[devid] = NULL;

	spin_unlock_irqrestore(&state_lock, flags);

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	put_device_state_wait(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = state_table[devid];
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ppr_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = state_table[devid];
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ctx_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

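/*
 * Module init: allocate the global device-state table, the fault
 * workqueue and the empty page table, then register the PPR and
 * task-exit notifiers.
 */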
static int __init amd_iommu_v2_init(void)
{
	size_t state_table_size;
	int ret;

	pr_info("AMD IOMMUv2 driver by Joerg Roedel <joerg.roedel@amd.com>\n");

	spin_lock_init(&state_lock);

	state_table_size = MAX_DEVICES * sizeof(struct device_state *);
	state_table = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO,
					       get_order(state_table_size));
	if (state_table == NULL)
		return -ENOMEM;

	ret = -ENOMEM;
	iommu_wq = create_workqueue("amd_iommu_v2");
	if (iommu_wq == NULL)
		goto out_free;

	ret = -ENOMEM;
	empty_page_table = (u64 *)get_zeroed_page(GFP_KERNEL);
	if (empty_page_table == NULL)
		goto out_destroy_wq;

	amd_iommu_register_ppr_notifier(&ppr_nb);
	profile_event_register(PROFILE_TASK_EXIT, &profile_nb);

	return 0;

out_destroy_wq:
	destroy_workqueue(iommu_wq);

out_free:
	free_pages((unsigned long)state_table, get_order(state_table_size));

	return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state;
	size_t state_table_size;
	int i;

	profile_event_unregister(PROFILE_TASK_EXIT, &profile_nb);
	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * The loop below might call flush_workqueue(), so call
	 * destroy_workqueue() after it
	 */
	for (i = 0; i < MAX_DEVICES; ++i) {
		dev_state = get_device_state(i);

		if (dev_state == NULL)
			continue;

		WARN_ON_ONCE(1);

		put_device_state(dev_state);
		amd_iommu_free_device(dev_state->pdev);
	}

	destroy_workqueue(iommu_wq);

	state_table_size = MAX_DEVICES * sizeof(struct device_state *);
	free_pages((unsigned long)state_table, get_order(state_table_size));

	free_page((unsigned long)empty_page_table);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);