auth_gss.c revision dc5ddce956660247e004a4b20a26b7d137ab1644
1/*
2 * linux/net/sunrpc/auth_gss/auth_gss.c
3 *
4 * RPCSEC_GSS client authentication.
5 *
6 *  Copyright (c) 2000 The Regents of the University of Michigan.
7 *  All rights reserved.
8 *
9 *  Dug Song       <dugsong@monkey.org>
10 *  Andy Adamson   <andros@umich.edu>
11 *
12 *  Redistribution and use in source and binary forms, with or without
13 *  modification, are permitted provided that the following conditions
14 *  are met:
15 *
16 *  1. Redistributions of source code must retain the above copyright
17 *     notice, this list of conditions and the following disclaimer.
18 *  2. Redistributions in binary form must reproduce the above copyright
19 *     notice, this list of conditions and the following disclaimer in the
20 *     documentation and/or other materials provided with the distribution.
21 *  3. Neither the name of the University nor the names of its
22 *     contributors may be used to endorse or promote products derived
23 *     from this software without specific prior written permission.
24 *
25 *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
26 *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
27 *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
29 *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30 *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31 *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
32 *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 */
37
38
39#include <linux/module.h>
40#include <linux/init.h>
41#include <linux/types.h>
42#include <linux/slab.h>
43#include <linux/sched.h>
44#include <linux/pagemap.h>
45#include <linux/sunrpc/clnt.h>
46#include <linux/sunrpc/auth.h>
47#include <linux/sunrpc/auth_gss.h>
48#include <linux/sunrpc/svcauth_gss.h>
49#include <linux/sunrpc/gss_err.h>
50#include <linux/workqueue.h>
51#include <linux/sunrpc/rpc_pipe_fs.h>
52#include <linux/sunrpc/gss_api.h>
53#include <asm/uaccess.h>
54
55static const struct rpc_authops authgss_ops;
56
57static const struct rpc_credops gss_credops;
58static const struct rpc_credops gss_nullops;
59
60#ifdef RPC_DEBUG
61# define RPCDBG_FACILITY	RPCDBG_AUTH
62#endif
63
64#define GSS_CRED_SLACK		1024
65/* length of a krb5 verifier (48), plus data added before arguments when
66 * using integrity (two 4-byte integers): */
67#define GSS_VERF_SLACK		100
68
69struct gss_auth {
70	struct kref kref;
71	struct rpc_auth rpc_auth;
72	struct gss_api_mech *mech;
73	enum rpc_gss_svc service;
74	struct rpc_clnt *client;
75	/*
76	 * There are two upcall pipes; dentry[1], named "gssd", is used
77	 * for the new text-based upcall; dentry[0] is named after the
78	 * mechanism (for example, "krb5") and exists for
79	 * backwards-compatibility with older gssd's.
80	 */
81	struct dentry *dentry[2];
82};
83
84/* pipe_version >= 0 if and only if someone has a pipe open. */
85static int pipe_version = -1;
86static atomic_t pipe_users = ATOMIC_INIT(0);
87static DEFINE_SPINLOCK(pipe_version_lock);
88static struct rpc_wait_queue pipe_version_rpc_waitqueue;
89static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);
90
91static void gss_free_ctx(struct gss_cl_ctx *);
92static const struct rpc_pipe_ops gss_upcall_ops_v0;
93static const struct rpc_pipe_ops gss_upcall_ops_v1;
94
95static inline struct gss_cl_ctx *
96gss_get_ctx(struct gss_cl_ctx *ctx)
97{
98	atomic_inc(&ctx->count);
99	return ctx;
100}
101
102static inline void
103gss_put_ctx(struct gss_cl_ctx *ctx)
104{
105	if (atomic_dec_and_test(&ctx->count))
106		gss_free_ctx(ctx);
107}
108
109/* gss_cred_set_ctx:
110 * called by gss_upcall_callback and gss_create_upcall in order
111 * to set the gss context. The actual exchange of an old context
112 * and a new one is protected by the inode->i_lock.
113 */
114static void
115gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
116{
117	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
118
119	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
120		return;
121	gss_get_ctx(ctx);
122	rcu_assign_pointer(gss_cred->gc_ctx, ctx);
123	set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
124	smp_mb__before_clear_bit();
125	clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
126}
127
128static const void *
129simple_get_bytes(const void *p, const void *end, void *res, size_t len)
130{
131	const void *q = (const void *)((const char *)p + len);
132	if (unlikely(q > end || q < p))
133		return ERR_PTR(-EFAULT);
134	memcpy(res, p, len);
135	return q;
136}
137
138static inline const void *
139simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest)
140{
141	const void *q;
142	unsigned int len;
143
144	p = simple_get_bytes(p, end, &len, sizeof(len));
145	if (IS_ERR(p))
146		return p;
147	q = (const void *)((const char *)p + len);
148	if (unlikely(q > end || q < p))
149		return ERR_PTR(-EFAULT);
150	dest->data = kmemdup(p, len, GFP_NOFS);
151	if (unlikely(dest->data == NULL))
152		return ERR_PTR(-ENOMEM);
153	dest->len = len;
154	return q;
155}
156
157static struct gss_cl_ctx *
158gss_cred_get_ctx(struct rpc_cred *cred)
159{
160	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
161	struct gss_cl_ctx *ctx = NULL;
162
163	rcu_read_lock();
164	if (gss_cred->gc_ctx)
165		ctx = gss_get_ctx(gss_cred->gc_ctx);
166	rcu_read_unlock();
167	return ctx;
168}
169
170static struct gss_cl_ctx *
171gss_alloc_context(void)
172{
173	struct gss_cl_ctx *ctx;
174
175	ctx = kzalloc(sizeof(*ctx), GFP_NOFS);
176	if (ctx != NULL) {
177		ctx->gc_proc = RPC_GSS_PROC_DATA;
178		ctx->gc_seq = 1;	/* NetApp 6.4R1 doesn't accept seq. no. 0 */
179		spin_lock_init(&ctx->gc_seq_lock);
180		atomic_set(&ctx->count,1);
181	}
182	return ctx;
183}
184
185#define GSSD_MIN_TIMEOUT (60 * 60)
186static const void *
187gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm)
188{
189	const void *q;
190	unsigned int seclen;
191	unsigned int timeout;
192	u32 window_size;
193	int ret;
194
195	/* First unsigned int gives the lifetime (in seconds) of the cred */
196	p = simple_get_bytes(p, end, &timeout, sizeof(timeout));
197	if (IS_ERR(p))
198		goto err;
199	if (timeout == 0)
200		timeout = GSSD_MIN_TIMEOUT;
201	ctx->gc_expiry = jiffies + (unsigned long)timeout * HZ * 3 / 4;
202	/* Sequence number window. Determines the maximum number of simultaneous requests */
203	p = simple_get_bytes(p, end, &window_size, sizeof(window_size));
204	if (IS_ERR(p))
205		goto err;
206	ctx->gc_win = window_size;
207	/* gssd signals an error by passing ctx->gc_win = 0: */
208	if (ctx->gc_win == 0) {
209		/*
210		 * in which case, p points to an error code. Anything other
211		 * than -EKEYEXPIRED gets converted to -EACCES.
212		 */
213		p = simple_get_bytes(p, end, &ret, sizeof(ret));
214		if (!IS_ERR(p))
215			p = (ret == -EKEYEXPIRED) ? ERR_PTR(-EKEYEXPIRED) :
216						    ERR_PTR(-EACCES);
217		goto err;
218	}
219	/* copy the opaque wire context */
220	p = simple_get_netobj(p, end, &ctx->gc_wire_ctx);
221	if (IS_ERR(p))
222		goto err;
223	/* import the opaque security context */
224	p  = simple_get_bytes(p, end, &seclen, sizeof(seclen));
225	if (IS_ERR(p))
226		goto err;
227	q = (const void *)((const char *)p + seclen);
228	if (unlikely(q > end || q < p)) {
229		p = ERR_PTR(-EFAULT);
230		goto err;
231	}
232	ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx);
233	if (ret < 0) {
234		p = ERR_PTR(ret);
235		goto err;
236	}
237	return q;
238err:
239	dprintk("RPC:       gss_fill_context returning %ld\n", -PTR_ERR(p));
240	return p;
241}
242
243#define UPCALL_BUF_LEN 128
244
245struct gss_upcall_msg {
246	atomic_t count;
247	uid_t	uid;
248	struct rpc_pipe_msg msg;
249	struct list_head list;
250	struct gss_auth *auth;
251	struct rpc_inode *inode;
252	struct rpc_wait_queue rpc_waitqueue;
253	wait_queue_head_t waitqueue;
254	struct gss_cl_ctx *ctx;
255	char databuf[UPCALL_BUF_LEN];
256};
257
258static int get_pipe_version(void)
259{
260	int ret;
261
262	spin_lock(&pipe_version_lock);
263	if (pipe_version >= 0) {
264		atomic_inc(&pipe_users);
265		ret = pipe_version;
266	} else
267		ret = -EAGAIN;
268	spin_unlock(&pipe_version_lock);
269	return ret;
270}
271
272static void put_pipe_version(void)
273{
274	if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) {
275		pipe_version = -1;
276		spin_unlock(&pipe_version_lock);
277	}
278}
279
280static void
281gss_release_msg(struct gss_upcall_msg *gss_msg)
282{
283	if (!atomic_dec_and_test(&gss_msg->count))
284		return;
285	put_pipe_version();
286	BUG_ON(!list_empty(&gss_msg->list));
287	if (gss_msg->ctx != NULL)
288		gss_put_ctx(gss_msg->ctx);
289	rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue);
290	kfree(gss_msg);
291}
292
293static struct gss_upcall_msg *
294__gss_find_upcall(struct rpc_inode *rpci, uid_t uid)
295{
296	struct gss_upcall_msg *pos;
297	list_for_each_entry(pos, &rpci->in_downcall, list) {
298		if (pos->uid != uid)
299			continue;
300		atomic_inc(&pos->count);
301		dprintk("RPC:       gss_find_upcall found msg %p\n", pos);
302		return pos;
303	}
304	dprintk("RPC:       gss_find_upcall found nothing\n");
305	return NULL;
306}
307
308/* Try to add an upcall to the pipefs queue.
309 * If an upcall owned by our uid already exists, then we return a reference
310 * to that upcall instead of adding the new upcall.
311 */
312static inline struct gss_upcall_msg *
313gss_add_msg(struct gss_upcall_msg *gss_msg)
314{
315	struct rpc_inode *rpci = gss_msg->inode;
316	struct inode *inode = &rpci->vfs_inode;
317	struct gss_upcall_msg *old;
318
319	spin_lock(&inode->i_lock);
320	old = __gss_find_upcall(rpci, gss_msg->uid);
321	if (old == NULL) {
322		atomic_inc(&gss_msg->count);
323		list_add(&gss_msg->list, &rpci->in_downcall);
324	} else
325		gss_msg = old;
326	spin_unlock(&inode->i_lock);
327	return gss_msg;
328}
329
330static void
331__gss_unhash_msg(struct gss_upcall_msg *gss_msg)
332{
333	list_del_init(&gss_msg->list);
334	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
335	wake_up_all(&gss_msg->waitqueue);
336	atomic_dec(&gss_msg->count);
337}
338
339static void
340gss_unhash_msg(struct gss_upcall_msg *gss_msg)
341{
342	struct inode *inode = &gss_msg->inode->vfs_inode;
343
344	if (list_empty(&gss_msg->list))
345		return;
346	spin_lock(&inode->i_lock);
347	if (!list_empty(&gss_msg->list))
348		__gss_unhash_msg(gss_msg);
349	spin_unlock(&inode->i_lock);
350}
351
352static void
353gss_upcall_callback(struct rpc_task *task)
354{
355	struct gss_cred *gss_cred = container_of(task->tk_msg.rpc_cred,
356			struct gss_cred, gc_base);
357	struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall;
358	struct inode *inode = &gss_msg->inode->vfs_inode;
359
360	spin_lock(&inode->i_lock);
361	if (gss_msg->ctx)
362		gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_msg->ctx);
363	else
364		task->tk_status = gss_msg->msg.errno;
365	gss_cred->gc_upcall = NULL;
366	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
367	spin_unlock(&inode->i_lock);
368	gss_release_msg(gss_msg);
369}
370
371static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
372{
373	gss_msg->msg.data = &gss_msg->uid;
374	gss_msg->msg.len = sizeof(gss_msg->uid);
375}
376
377static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
378				struct rpc_clnt *clnt, int machine_cred)
379{
380	char *p = gss_msg->databuf;
381	int len = 0;
382
383	gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ",
384				   gss_msg->auth->mech->gm_name,
385				   gss_msg->uid);
386	p += gss_msg->msg.len;
387	if (clnt->cl_principal) {
388		len = sprintf(p, "target=%s ", clnt->cl_principal);
389		p += len;
390		gss_msg->msg.len += len;
391	}
392	if (machine_cred) {
393		len = sprintf(p, "service=* ");
394		p += len;
395		gss_msg->msg.len += len;
396	} else if (!strcmp(clnt->cl_program->name, "nfs4_cb")) {
397		len = sprintf(p, "service=nfs ");
398		p += len;
399		gss_msg->msg.len += len;
400	}
401	len = sprintf(p, "\n");
402	gss_msg->msg.len += len;
403
404	gss_msg->msg.data = gss_msg->databuf;
405	BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN);
406}
407
408static void gss_encode_msg(struct gss_upcall_msg *gss_msg,
409				struct rpc_clnt *clnt, int machine_cred)
410{
411	if (pipe_version == 0)
412		gss_encode_v0_msg(gss_msg);
413	else /* pipe_version == 1 */
414		gss_encode_v1_msg(gss_msg, clnt, machine_cred);
415}
416
417static inline struct gss_upcall_msg *
418gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid, struct rpc_clnt *clnt,
419		int machine_cred)
420{
421	struct gss_upcall_msg *gss_msg;
422	int vers;
423
424	gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
425	if (gss_msg == NULL)
426		return ERR_PTR(-ENOMEM);
427	vers = get_pipe_version();
428	if (vers < 0) {
429		kfree(gss_msg);
430		return ERR_PTR(vers);
431	}
432	gss_msg->inode = RPC_I(gss_auth->dentry[vers]->d_inode);
433	INIT_LIST_HEAD(&gss_msg->list);
434	rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
435	init_waitqueue_head(&gss_msg->waitqueue);
436	atomic_set(&gss_msg->count, 1);
437	gss_msg->uid = uid;
438	gss_msg->auth = gss_auth;
439	gss_encode_msg(gss_msg, clnt, machine_cred);
440	return gss_msg;
441}
442
443static struct gss_upcall_msg *
444gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
445{
446	struct gss_cred *gss_cred = container_of(cred,
447			struct gss_cred, gc_base);
448	struct gss_upcall_msg *gss_new, *gss_msg;
449	uid_t uid = cred->cr_uid;
450
451	gss_new = gss_alloc_msg(gss_auth, uid, clnt, gss_cred->gc_machine_cred);
452	if (IS_ERR(gss_new))
453		return gss_new;
454	gss_msg = gss_add_msg(gss_new);
455	if (gss_msg == gss_new) {
456		struct inode *inode = &gss_new->inode->vfs_inode;
457		int res = rpc_queue_upcall(inode, &gss_new->msg);
458		if (res) {
459			gss_unhash_msg(gss_new);
460			gss_msg = ERR_PTR(res);
461		}
462	} else
463		gss_release_msg(gss_new);
464	return gss_msg;
465}
466
467static void warn_gssd(void)
468{
469	static unsigned long ratelimit;
470	unsigned long now = jiffies;
471
472	if (time_after(now, ratelimit)) {
473		printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
474				"Please check user daemon is running.\n");
475		ratelimit = now + 15*HZ;
476	}
477}
478
479static inline int
480gss_refresh_upcall(struct rpc_task *task)
481{
482	struct rpc_cred *cred = task->tk_msg.rpc_cred;
483	struct gss_auth *gss_auth = container_of(cred->cr_auth,
484			struct gss_auth, rpc_auth);
485	struct gss_cred *gss_cred = container_of(cred,
486			struct gss_cred, gc_base);
487	struct gss_upcall_msg *gss_msg;
488	struct inode *inode;
489	int err = 0;
490
491	dprintk("RPC: %5u gss_refresh_upcall for uid %u\n", task->tk_pid,
492								cred->cr_uid);
493	gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
494	if (PTR_ERR(gss_msg) == -EAGAIN) {
495		/* XXX: warning on the first, under the assumption we
496		 * shouldn't normally hit this case on a refresh. */
497		warn_gssd();
498		task->tk_timeout = 15*HZ;
499		rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL);
500		return 0;
501	}
502	if (IS_ERR(gss_msg)) {
503		err = PTR_ERR(gss_msg);
504		goto out;
505	}
506	inode = &gss_msg->inode->vfs_inode;
507	spin_lock(&inode->i_lock);
508	if (gss_cred->gc_upcall != NULL)
509		rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
510	else if (gss_msg->ctx != NULL) {
511		gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_msg->ctx);
512		gss_cred->gc_upcall = NULL;
513		rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
514	} else if (gss_msg->msg.errno >= 0) {
515		task->tk_timeout = 0;
516		gss_cred->gc_upcall = gss_msg;
517		/* gss_upcall_callback will release the reference to gss_upcall_msg */
518		atomic_inc(&gss_msg->count);
519		rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback);
520	} else
521		err = gss_msg->msg.errno;
522	spin_unlock(&inode->i_lock);
523	gss_release_msg(gss_msg);
524out:
525	dprintk("RPC: %5u gss_refresh_upcall for uid %u result %d\n",
526			task->tk_pid, cred->cr_uid, err);
527	return err;
528}
529
530static inline int
531gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
532{
533	struct inode *inode;
534	struct rpc_cred *cred = &gss_cred->gc_base;
535	struct gss_upcall_msg *gss_msg;
536	DEFINE_WAIT(wait);
537	int err = 0;
538
539	dprintk("RPC:       gss_upcall for uid %u\n", cred->cr_uid);
540retry:
541	gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
542	if (PTR_ERR(gss_msg) == -EAGAIN) {
543		err = wait_event_interruptible_timeout(pipe_version_waitqueue,
544				pipe_version >= 0, 15*HZ);
545		if (err)
546			goto out;
547		if (pipe_version < 0)
548			warn_gssd();
549		goto retry;
550	}
551	if (IS_ERR(gss_msg)) {
552		err = PTR_ERR(gss_msg);
553		goto out;
554	}
555	inode = &gss_msg->inode->vfs_inode;
556	for (;;) {
557		prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_INTERRUPTIBLE);
558		spin_lock(&inode->i_lock);
559		if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
560			break;
561		}
562		spin_unlock(&inode->i_lock);
563		if (signalled()) {
564			err = -ERESTARTSYS;
565			goto out_intr;
566		}
567		schedule();
568	}
569	if (gss_msg->ctx)
570		gss_cred_set_ctx(cred, gss_msg->ctx);
571	else
572		err = gss_msg->msg.errno;
573	spin_unlock(&inode->i_lock);
574out_intr:
575	finish_wait(&gss_msg->waitqueue, &wait);
576	gss_release_msg(gss_msg);
577out:
578	dprintk("RPC:       gss_create_upcall for uid %u result %d\n",
579			cred->cr_uid, err);
580	return err;
581}
582
583static ssize_t
584gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg,
585		char __user *dst, size_t buflen)
586{
587	char *data = (char *)msg->data + msg->copied;
588	size_t mlen = min(msg->len, buflen);
589	unsigned long left;
590
591	left = copy_to_user(dst, data, mlen);
592	if (left == mlen) {
593		msg->errno = -EFAULT;
594		return -EFAULT;
595	}
596
597	mlen -= left;
598	msg->copied += mlen;
599	msg->errno = 0;
600	return mlen;
601}
602
603#define MSG_BUF_MAXSIZE 1024
604
605static ssize_t
606gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
607{
608	const void *p, *end;
609	void *buf;
610	struct gss_upcall_msg *gss_msg;
611	struct inode *inode = filp->f_path.dentry->d_inode;
612	struct gss_cl_ctx *ctx;
613	uid_t uid;
614	ssize_t err = -EFBIG;
615
616	if (mlen > MSG_BUF_MAXSIZE)
617		goto out;
618	err = -ENOMEM;
619	buf = kmalloc(mlen, GFP_NOFS);
620	if (!buf)
621		goto out;
622
623	err = -EFAULT;
624	if (copy_from_user(buf, src, mlen))
625		goto err;
626
627	end = (const void *)((char *)buf + mlen);
628	p = simple_get_bytes(buf, end, &uid, sizeof(uid));
629	if (IS_ERR(p)) {
630		err = PTR_ERR(p);
631		goto err;
632	}
633
634	err = -ENOMEM;
635	ctx = gss_alloc_context();
636	if (ctx == NULL)
637		goto err;
638
639	err = -ENOENT;
640	/* Find a matching upcall */
641	spin_lock(&inode->i_lock);
642	gss_msg = __gss_find_upcall(RPC_I(inode), uid);
643	if (gss_msg == NULL) {
644		spin_unlock(&inode->i_lock);
645		goto err_put_ctx;
646	}
647	list_del_init(&gss_msg->list);
648	spin_unlock(&inode->i_lock);
649
650	p = gss_fill_context(p, end, ctx, gss_msg->auth->mech);
651	if (IS_ERR(p)) {
652		err = PTR_ERR(p);
653		switch (err) {
654		case -EACCES:
655		case -EKEYEXPIRED:
656			gss_msg->msg.errno = err;
657			err = mlen;
658			break;
659		case -EFAULT:
660		case -ENOMEM:
661		case -EINVAL:
662		case -ENOSYS:
663			gss_msg->msg.errno = -EAGAIN;
664			break;
665		default:
666			printk(KERN_CRIT "%s: bad return from "
667				"gss_fill_context: %zd\n", __func__, err);
668			BUG();
669		}
670		goto err_release_msg;
671	}
672	gss_msg->ctx = gss_get_ctx(ctx);
673	err = mlen;
674
675err_release_msg:
676	spin_lock(&inode->i_lock);
677	__gss_unhash_msg(gss_msg);
678	spin_unlock(&inode->i_lock);
679	gss_release_msg(gss_msg);
680err_put_ctx:
681	gss_put_ctx(ctx);
682err:
683	kfree(buf);
684out:
685	dprintk("RPC:       gss_pipe_downcall returning %Zd\n", err);
686	return err;
687}
688
689static int gss_pipe_open(struct inode *inode, int new_version)
690{
691	int ret = 0;
692
693	spin_lock(&pipe_version_lock);
694	if (pipe_version < 0) {
695		/* First open of any gss pipe determines the version: */
696		pipe_version = new_version;
697		rpc_wake_up(&pipe_version_rpc_waitqueue);
698		wake_up(&pipe_version_waitqueue);
699	} else if (pipe_version != new_version) {
700		/* Trying to open a pipe of a different version */
701		ret = -EBUSY;
702		goto out;
703	}
704	atomic_inc(&pipe_users);
705out:
706	spin_unlock(&pipe_version_lock);
707	return ret;
708
709}
710
711static int gss_pipe_open_v0(struct inode *inode)
712{
713	return gss_pipe_open(inode, 0);
714}
715
716static int gss_pipe_open_v1(struct inode *inode)
717{
718	return gss_pipe_open(inode, 1);
719}
720
721static void
722gss_pipe_release(struct inode *inode)
723{
724	struct rpc_inode *rpci = RPC_I(inode);
725	struct gss_upcall_msg *gss_msg;
726
727	spin_lock(&inode->i_lock);
728	while (!list_empty(&rpci->in_downcall)) {
729
730		gss_msg = list_entry(rpci->in_downcall.next,
731				struct gss_upcall_msg, list);
732		gss_msg->msg.errno = -EPIPE;
733		atomic_inc(&gss_msg->count);
734		__gss_unhash_msg(gss_msg);
735		spin_unlock(&inode->i_lock);
736		gss_release_msg(gss_msg);
737		spin_lock(&inode->i_lock);
738	}
739	spin_unlock(&inode->i_lock);
740
741	put_pipe_version();
742}
743
744static void
745gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
746{
747	struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
748
749	if (msg->errno < 0) {
750		dprintk("RPC:       gss_pipe_destroy_msg releasing msg %p\n",
751				gss_msg);
752		atomic_inc(&gss_msg->count);
753		gss_unhash_msg(gss_msg);
754		if (msg->errno == -ETIMEDOUT)
755			warn_gssd();
756		gss_release_msg(gss_msg);
757	}
758}
759
760/*
761 * NOTE: we have the opportunity to use different
762 * parameters based on the input flavor (which must be a pseudoflavor)
763 */
764static struct rpc_auth *
765gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
766{
767	struct gss_auth *gss_auth;
768	struct rpc_auth * auth;
769	int err = -ENOMEM; /* XXX? */
770
771	dprintk("RPC:       creating GSS authenticator for client %p\n", clnt);
772
773	if (!try_module_get(THIS_MODULE))
774		return ERR_PTR(err);
775	if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
776		goto out_dec;
777	gss_auth->client = clnt;
778	err = -EINVAL;
779	gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
780	if (!gss_auth->mech) {
781		printk(KERN_WARNING "%s: Pseudoflavor %d not found!\n",
782				__func__, flavor);
783		goto err_free;
784	}
785	gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor);
786	if (gss_auth->service == 0)
787		goto err_put_mech;
788	auth = &gss_auth->rpc_auth;
789	auth->au_cslack = GSS_CRED_SLACK >> 2;
790	auth->au_rslack = GSS_VERF_SLACK >> 2;
791	auth->au_ops = &authgss_ops;
792	auth->au_flavor = flavor;
793	atomic_set(&auth->au_count, 1);
794	kref_init(&gss_auth->kref);
795
796	/*
797	 * Note: if we created the old pipe first, then someone who
798	 * examined the directory at the right moment might conclude
799	 * that we supported only the old pipe.  So we instead create
800	 * the new pipe first.
801	 */
802	gss_auth->dentry[1] = rpc_mkpipe(clnt->cl_path.dentry,
803					 "gssd",
804					 clnt, &gss_upcall_ops_v1,
805					 RPC_PIPE_WAIT_FOR_OPEN);
806	if (IS_ERR(gss_auth->dentry[1])) {
807		err = PTR_ERR(gss_auth->dentry[1]);
808		goto err_put_mech;
809	}
810
811	gss_auth->dentry[0] = rpc_mkpipe(clnt->cl_path.dentry,
812					 gss_auth->mech->gm_name,
813					 clnt, &gss_upcall_ops_v0,
814					 RPC_PIPE_WAIT_FOR_OPEN);
815	if (IS_ERR(gss_auth->dentry[0])) {
816		err = PTR_ERR(gss_auth->dentry[0]);
817		goto err_unlink_pipe_1;
818	}
819	err = rpcauth_init_credcache(auth);
820	if (err)
821		goto err_unlink_pipe_0;
822
823	return auth;
824err_unlink_pipe_0:
825	rpc_unlink(gss_auth->dentry[0]);
826err_unlink_pipe_1:
827	rpc_unlink(gss_auth->dentry[1]);
828err_put_mech:
829	gss_mech_put(gss_auth->mech);
830err_free:
831	kfree(gss_auth);
832out_dec:
833	module_put(THIS_MODULE);
834	return ERR_PTR(err);
835}
836
837static void
838gss_free(struct gss_auth *gss_auth)
839{
840	rpc_unlink(gss_auth->dentry[1]);
841	rpc_unlink(gss_auth->dentry[0]);
842	gss_mech_put(gss_auth->mech);
843
844	kfree(gss_auth);
845	module_put(THIS_MODULE);
846}
847
848static void
849gss_free_callback(struct kref *kref)
850{
851	struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref);
852
853	gss_free(gss_auth);
854}
855
856static void
857gss_destroy(struct rpc_auth *auth)
858{
859	struct gss_auth *gss_auth;
860
861	dprintk("RPC:       destroying GSS authenticator %p flavor %d\n",
862			auth, auth->au_flavor);
863
864	rpcauth_destroy_credcache(auth);
865
866	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
867	kref_put(&gss_auth->kref, gss_free_callback);
868}
869
870/*
871 * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call
872 * to the server with the GSS control procedure field set to
873 * RPC_GSS_PROC_DESTROY. This should normally cause the server to release
874 * all RPCSEC_GSS state associated with that context.
875 */
876static int
877gss_destroying_context(struct rpc_cred *cred)
878{
879	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
880	struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
881	struct rpc_task *task;
882
883	if (gss_cred->gc_ctx == NULL ||
884	    test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
885		return 0;
886
887	gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY;
888	cred->cr_ops = &gss_nullops;
889
890	/* Take a reference to ensure the cred will be destroyed either
891	 * by the RPC call or by the put_rpccred() below */
892	get_rpccred(cred);
893
894	task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT);
895	if (!IS_ERR(task))
896		rpc_put_task(task);
897
898	put_rpccred(cred);
899	return 1;
900}
901
902/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure
903 * to create a new cred or context, so they check that things have been
904 * allocated before freeing them. */
905static void
906gss_do_free_ctx(struct gss_cl_ctx *ctx)
907{
908	dprintk("RPC:       gss_free_ctx\n");
909
910	kfree(ctx->gc_wire_ctx.data);
911	kfree(ctx);
912}
913
914static void
915gss_free_ctx_callback(struct rcu_head *head)
916{
917	struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu);
918	gss_do_free_ctx(ctx);
919}
920
921static void
922gss_free_ctx(struct gss_cl_ctx *ctx)
923{
924	struct gss_ctx *gc_gss_ctx;
925
926	gc_gss_ctx = rcu_dereference(ctx->gc_gss_ctx);
927	rcu_assign_pointer(ctx->gc_gss_ctx, NULL);
928	call_rcu(&ctx->gc_rcu, gss_free_ctx_callback);
929	if (gc_gss_ctx)
930		gss_delete_sec_context(&gc_gss_ctx);
931}
932
933static void
934gss_free_cred(struct gss_cred *gss_cred)
935{
936	dprintk("RPC:       gss_free_cred %p\n", gss_cred);
937	kfree(gss_cred);
938}
939
940static void
941gss_free_cred_callback(struct rcu_head *head)
942{
943	struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu);
944	gss_free_cred(gss_cred);
945}
946
947static void
948gss_destroy_nullcred(struct rpc_cred *cred)
949{
950	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
951	struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
952	struct gss_cl_ctx *ctx = gss_cred->gc_ctx;
953
954	rcu_assign_pointer(gss_cred->gc_ctx, NULL);
955	call_rcu(&cred->cr_rcu, gss_free_cred_callback);
956	if (ctx)
957		gss_put_ctx(ctx);
958	kref_put(&gss_auth->kref, gss_free_callback);
959}
960
961static void
962gss_destroy_cred(struct rpc_cred *cred)
963{
964
965	if (gss_destroying_context(cred))
966		return;
967	gss_destroy_nullcred(cred);
968}
969
970/*
971 * Lookup RPCSEC_GSS cred for the current process
972 */
973static struct rpc_cred *
974gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
975{
976	return rpcauth_lookup_credcache(auth, acred, flags);
977}
978
979static struct rpc_cred *
980gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
981{
982	struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
983	struct gss_cred	*cred = NULL;
984	int err = -ENOMEM;
985
986	dprintk("RPC:       gss_create_cred for uid %d, flavor %d\n",
987		acred->uid, auth->au_flavor);
988
989	if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS)))
990		goto out_err;
991
992	rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops);
993	/*
994	 * Note: in order to force a call to call_refresh(), we deliberately
995	 * fail to flag the credential as RPCAUTH_CRED_UPTODATE.
996	 */
997	cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW;
998	cred->gc_service = gss_auth->service;
999	cred->gc_machine_cred = acred->machine_cred;
1000	kref_get(&gss_auth->kref);
1001	return &cred->gc_base;
1002
1003out_err:
1004	dprintk("RPC:       gss_create_cred failed with error %d\n", err);
1005	return ERR_PTR(err);
1006}
1007
1008static int
1009gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred)
1010{
1011	struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1012	struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base);
1013	int err;
1014
1015	do {
1016		err = gss_create_upcall(gss_auth, gss_cred);
1017	} while (err == -EAGAIN);
1018	return err;
1019}
1020
1021static int
1022gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags)
1023{
1024	struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
1025
1026	if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
1027		goto out;
1028	/* Don't match with creds that have expired. */
1029	if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
1030		return 0;
1031	if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags))
1032		return 0;
1033out:
1034	if (acred->machine_cred != gss_cred->gc_machine_cred)
1035		return 0;
1036	return (rc->cr_uid == acred->uid);
1037}
1038
1039/*
1040* Marshal credentials.
1041* Maybe we should keep a cached credential for performance reasons.
1042*/
1043static __be32 *
1044gss_marshal(struct rpc_task *task, __be32 *p)
1045{
1046	struct rpc_cred *cred = task->tk_msg.rpc_cred;
1047	struct gss_cred	*gss_cred = container_of(cred, struct gss_cred,
1048						 gc_base);
1049	struct gss_cl_ctx	*ctx = gss_cred_get_ctx(cred);
1050	__be32		*cred_len;
1051	struct rpc_rqst *req = task->tk_rqstp;
1052	u32             maj_stat = 0;
1053	struct xdr_netobj mic;
1054	struct kvec	iov;
1055	struct xdr_buf	verf_buf;
1056
1057	dprintk("RPC: %5u gss_marshal\n", task->tk_pid);
1058
1059	*p++ = htonl(RPC_AUTH_GSS);
1060	cred_len = p++;
1061
1062	spin_lock(&ctx->gc_seq_lock);
1063	req->rq_seqno = ctx->gc_seq++;
1064	spin_unlock(&ctx->gc_seq_lock);
1065
1066	*p++ = htonl((u32) RPC_GSS_VERSION);
1067	*p++ = htonl((u32) ctx->gc_proc);
1068	*p++ = htonl((u32) req->rq_seqno);
1069	*p++ = htonl((u32) gss_cred->gc_service);
1070	p = xdr_encode_netobj(p, &ctx->gc_wire_ctx);
1071	*cred_len = htonl((p - (cred_len + 1)) << 2);
1072
1073	/* We compute the checksum for the verifier over the xdr-encoded bytes
1074	 * starting with the xid and ending at the end of the credential: */
1075	iov.iov_base = xprt_skip_transport_header(task->tk_xprt,
1076					req->rq_snd_buf.head[0].iov_base);
1077	iov.iov_len = (u8 *)p - (u8 *)iov.iov_base;
1078	xdr_buf_from_iov(&iov, &verf_buf);
1079
1080	/* set verifier flavor*/
1081	*p++ = htonl(RPC_AUTH_GSS);
1082
1083	mic.data = (u8 *)(p + 1);
1084	maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1085	if (maj_stat == GSS_S_CONTEXT_EXPIRED) {
1086		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1087	} else if (maj_stat != 0) {
1088		printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat);
1089		goto out_put_ctx;
1090	}
1091	p = xdr_encode_opaque(p, NULL, mic.len);
1092	gss_put_ctx(ctx);
1093	return p;
1094out_put_ctx:
1095	gss_put_ctx(ctx);
1096	return NULL;
1097}
1098
1099static int gss_renew_cred(struct rpc_task *task)
1100{
1101	struct rpc_cred *oldcred = task->tk_msg.rpc_cred;
1102	struct gss_cred *gss_cred = container_of(oldcred,
1103						 struct gss_cred,
1104						 gc_base);
1105	struct rpc_auth *auth = oldcred->cr_auth;
1106	struct auth_cred acred = {
1107		.uid = oldcred->cr_uid,
1108		.machine_cred = gss_cred->gc_machine_cred,
1109	};
1110	struct rpc_cred *new;
1111
1112	new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW);
1113	if (IS_ERR(new))
1114		return PTR_ERR(new);
1115	task->tk_msg.rpc_cred = new;
1116	put_rpccred(oldcred);
1117	return 0;
1118}
1119
1120/*
1121* Refresh credentials. XXX - finish
1122*/
1123static int
1124gss_refresh(struct rpc_task *task)
1125{
1126	struct rpc_cred *cred = task->tk_msg.rpc_cred;
1127	int ret = 0;
1128
1129	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
1130			!test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) {
1131		ret = gss_renew_cred(task);
1132		if (ret < 0)
1133			goto out;
1134		cred = task->tk_msg.rpc_cred;
1135	}
1136
1137	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
1138		ret = gss_refresh_upcall(task);
1139out:
1140	return ret;
1141}
1142
1143/* Dummy refresh routine: used only when destroying the context */
1144static int
1145gss_refresh_null(struct rpc_task *task)
1146{
1147	return -EACCES;
1148}
1149
1150static __be32 *
1151gss_validate(struct rpc_task *task, __be32 *p)
1152{
1153	struct rpc_cred *cred = task->tk_msg.rpc_cred;
1154	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1155	__be32		seq;
1156	struct kvec	iov;
1157	struct xdr_buf	verf_buf;
1158	struct xdr_netobj mic;
1159	u32		flav,len;
1160	u32		maj_stat;
1161
1162	dprintk("RPC: %5u gss_validate\n", task->tk_pid);
1163
1164	flav = ntohl(*p++);
1165	if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE)
1166		goto out_bad;
1167	if (flav != RPC_AUTH_GSS)
1168		goto out_bad;
1169	seq = htonl(task->tk_rqstp->rq_seqno);
1170	iov.iov_base = &seq;
1171	iov.iov_len = sizeof(seq);
1172	xdr_buf_from_iov(&iov, &verf_buf);
1173	mic.data = (u8 *)p;
1174	mic.len = len;
1175
1176	maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1177	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1178		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1179	if (maj_stat) {
1180		dprintk("RPC: %5u gss_validate: gss_verify_mic returned "
1181				"error 0x%08x\n", task->tk_pid, maj_stat);
1182		goto out_bad;
1183	}
1184	/* We leave it to unwrap to calculate au_rslack. For now we just
1185	 * calculate the length of the verifier: */
1186	cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2;
1187	gss_put_ctx(ctx);
1188	dprintk("RPC: %5u gss_validate: gss_verify_mic succeeded.\n",
1189			task->tk_pid);
1190	return p + XDR_QUADLEN(len);
1191out_bad:
1192	gss_put_ctx(ctx);
1193	dprintk("RPC: %5u gss_validate failed.\n", task->tk_pid);
1194	return NULL;
1195}
1196
1197static inline int
1198gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1199		kxdrproc_t encode, struct rpc_rqst *rqstp, __be32 *p, void *obj)
1200{
1201	struct xdr_buf	*snd_buf = &rqstp->rq_snd_buf;
1202	struct xdr_buf	integ_buf;
1203	__be32          *integ_len = NULL;
1204	struct xdr_netobj mic;
1205	u32		offset;
1206	__be32		*q;
1207	struct kvec	*iov;
1208	u32             maj_stat = 0;
1209	int		status = -EIO;
1210
1211	integ_len = p++;
1212	offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
1213	*p++ = htonl(rqstp->rq_seqno);
1214
1215	status = encode(rqstp, p, obj);
1216	if (status)
1217		return status;
1218
1219	if (xdr_buf_subsegment(snd_buf, &integ_buf,
1220				offset, snd_buf->len - offset))
1221		return status;
1222	*integ_len = htonl(integ_buf.len);
1223
1224	/* guess whether we're in the head or the tail: */
1225	if (snd_buf->page_len || snd_buf->tail[0].iov_len)
1226		iov = snd_buf->tail;
1227	else
1228		iov = snd_buf->head;
1229	p = iov->iov_base + iov->iov_len;
1230	mic.data = (u8 *)(p + 1);
1231
1232	maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1233	status = -EIO; /* XXX? */
1234	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1235		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1236	else if (maj_stat)
1237		return status;
1238	q = xdr_encode_opaque(p, NULL, mic.len);
1239
1240	offset = (u8 *)q - (u8 *)p;
1241	iov->iov_len += offset;
1242	snd_buf->len += offset;
1243	return 0;
1244}
1245
1246static void
1247priv_release_snd_buf(struct rpc_rqst *rqstp)
1248{
1249	int i;
1250
1251	for (i=0; i < rqstp->rq_enc_pages_num; i++)
1252		__free_page(rqstp->rq_enc_pages[i]);
1253	kfree(rqstp->rq_enc_pages);
1254}
1255
1256static int
1257alloc_enc_pages(struct rpc_rqst *rqstp)
1258{
1259	struct xdr_buf *snd_buf = &rqstp->rq_snd_buf;
1260	int first, last, i;
1261
1262	if (snd_buf->page_len == 0) {
1263		rqstp->rq_enc_pages_num = 0;
1264		return 0;
1265	}
1266
1267	first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1268	last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_CACHE_SHIFT;
1269	rqstp->rq_enc_pages_num = last - first + 1 + 1;
1270	rqstp->rq_enc_pages
1271		= kmalloc(rqstp->rq_enc_pages_num * sizeof(struct page *),
1272				GFP_NOFS);
1273	if (!rqstp->rq_enc_pages)
1274		goto out;
1275	for (i=0; i < rqstp->rq_enc_pages_num; i++) {
1276		rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS);
1277		if (rqstp->rq_enc_pages[i] == NULL)
1278			goto out_free;
1279	}
1280	rqstp->rq_release_snd_buf = priv_release_snd_buf;
1281	return 0;
1282out_free:
1283	for (i--; i >= 0; i--) {
1284		__free_page(rqstp->rq_enc_pages[i]);
1285	}
1286out:
1287	return -EAGAIN;
1288}
1289
1290static inline int
1291gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1292		kxdrproc_t encode, struct rpc_rqst *rqstp, __be32 *p, void *obj)
1293{
1294	struct xdr_buf	*snd_buf = &rqstp->rq_snd_buf;
1295	u32		offset;
1296	u32             maj_stat;
1297	int		status;
1298	__be32		*opaque_len;
1299	struct page	**inpages;
1300	int		first;
1301	int		pad;
1302	struct kvec	*iov;
1303	char		*tmp;
1304
1305	opaque_len = p++;
1306	offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
1307	*p++ = htonl(rqstp->rq_seqno);
1308
1309	status = encode(rqstp, p, obj);
1310	if (status)
1311		return status;
1312
1313	status = alloc_enc_pages(rqstp);
1314	if (status)
1315		return status;
1316	first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1317	inpages = snd_buf->pages + first;
1318	snd_buf->pages = rqstp->rq_enc_pages;
1319	snd_buf->page_base -= first << PAGE_CACHE_SHIFT;
1320	/* Give the tail its own page, in case we need extra space in the
1321	 * head when wrapping: */
1322	if (snd_buf->page_len || snd_buf->tail[0].iov_len) {
1323		tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]);
1324		memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len);
1325		snd_buf->tail[0].iov_base = tmp;
1326	}
1327	maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages);
1328	/* RPC_SLACK_SPACE should prevent this ever happening: */
1329	BUG_ON(snd_buf->len > snd_buf->buflen);
1330	status = -EIO;
1331	/* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
1332	 * done anyway, so it's safe to put the request on the wire: */
1333	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1334		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1335	else if (maj_stat)
1336		return status;
1337
1338	*opaque_len = htonl(snd_buf->len - offset);
1339	/* guess whether we're in the head or the tail: */
1340	if (snd_buf->page_len || snd_buf->tail[0].iov_len)
1341		iov = snd_buf->tail;
1342	else
1343		iov = snd_buf->head;
1344	p = iov->iov_base + iov->iov_len;
1345	pad = 3 - ((snd_buf->len - offset - 1) & 3);
1346	memset(p, 0, pad);
1347	iov->iov_len += pad;
1348	snd_buf->len += pad;
1349
1350	return 0;
1351}
1352
1353static int
1354gss_wrap_req(struct rpc_task *task,
1355	     kxdrproc_t encode, void *rqstp, __be32 *p, void *obj)
1356{
1357	struct rpc_cred *cred = task->tk_msg.rpc_cred;
1358	struct gss_cred	*gss_cred = container_of(cred, struct gss_cred,
1359			gc_base);
1360	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1361	int             status = -EIO;
1362
1363	dprintk("RPC: %5u gss_wrap_req\n", task->tk_pid);
1364	if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
1365		/* The spec seems a little ambiguous here, but I think that not
1366		 * wrapping context destruction requests makes the most sense.
1367		 */
1368		status = encode(rqstp, p, obj);
1369		goto out;
1370	}
1371	switch (gss_cred->gc_service) {
1372		case RPC_GSS_SVC_NONE:
1373			status = encode(rqstp, p, obj);
1374			break;
1375		case RPC_GSS_SVC_INTEGRITY:
1376			status = gss_wrap_req_integ(cred, ctx, encode,
1377								rqstp, p, obj);
1378			break;
1379		case RPC_GSS_SVC_PRIVACY:
1380			status = gss_wrap_req_priv(cred, ctx, encode,
1381					rqstp, p, obj);
1382			break;
1383	}
1384out:
1385	gss_put_ctx(ctx);
1386	dprintk("RPC: %5u gss_wrap_req returning %d\n", task->tk_pid, status);
1387	return status;
1388}
1389
1390static inline int
1391gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1392		struct rpc_rqst *rqstp, __be32 **p)
1393{
1394	struct xdr_buf	*rcv_buf = &rqstp->rq_rcv_buf;
1395	struct xdr_buf integ_buf;
1396	struct xdr_netobj mic;
1397	u32 data_offset, mic_offset;
1398	u32 integ_len;
1399	u32 maj_stat;
1400	int status = -EIO;
1401
1402	integ_len = ntohl(*(*p)++);
1403	if (integ_len & 3)
1404		return status;
1405	data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1406	mic_offset = integ_len + data_offset;
1407	if (mic_offset > rcv_buf->len)
1408		return status;
1409	if (ntohl(*(*p)++) != rqstp->rq_seqno)
1410		return status;
1411
1412	if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
1413				mic_offset - data_offset))
1414		return status;
1415
1416	if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
1417		return status;
1418
1419	maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1420	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1421		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1422	if (maj_stat != GSS_S_COMPLETE)
1423		return status;
1424	return 0;
1425}
1426
1427static inline int
1428gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1429		struct rpc_rqst *rqstp, __be32 **p)
1430{
1431	struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1432	u32 offset;
1433	u32 opaque_len;
1434	u32 maj_stat;
1435	int status = -EIO;
1436
1437	opaque_len = ntohl(*(*p)++);
1438	offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1439	if (offset + opaque_len > rcv_buf->len)
1440		return status;
1441	/* remove padding: */
1442	rcv_buf->len = offset + opaque_len;
1443
1444	maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf);
1445	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1446		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1447	if (maj_stat != GSS_S_COMPLETE)
1448		return status;
1449	if (ntohl(*(*p)++) != rqstp->rq_seqno)
1450		return status;
1451
1452	return 0;
1453}
1454
1455
1456static int
1457gss_unwrap_resp(struct rpc_task *task,
1458		kxdrproc_t decode, void *rqstp, __be32 *p, void *obj)
1459{
1460	struct rpc_cred *cred = task->tk_msg.rpc_cred;
1461	struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
1462			gc_base);
1463	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1464	__be32		*savedp = p;
1465	struct kvec	*head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head;
1466	int		savedlen = head->iov_len;
1467	int             status = -EIO;
1468
1469	if (ctx->gc_proc != RPC_GSS_PROC_DATA)
1470		goto out_decode;
1471	switch (gss_cred->gc_service) {
1472		case RPC_GSS_SVC_NONE:
1473			break;
1474		case RPC_GSS_SVC_INTEGRITY:
1475			status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p);
1476			if (status)
1477				goto out;
1478			break;
1479		case RPC_GSS_SVC_PRIVACY:
1480			status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p);
1481			if (status)
1482				goto out;
1483			break;
1484	}
1485	/* take into account extra slack for integrity and privacy cases: */
1486	cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
1487						+ (savedlen - head->iov_len);
1488out_decode:
1489	status = decode(rqstp, p, obj);
1490out:
1491	gss_put_ctx(ctx);
1492	dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid,
1493			status);
1494	return status;
1495}
1496
1497static const struct rpc_authops authgss_ops = {
1498	.owner		= THIS_MODULE,
1499	.au_flavor	= RPC_AUTH_GSS,
1500	.au_name	= "RPCSEC_GSS",
1501	.create		= gss_create,
1502	.destroy	= gss_destroy,
1503	.lookup_cred	= gss_lookup_cred,
1504	.crcreate	= gss_create_cred
1505};
1506
1507static const struct rpc_credops gss_credops = {
1508	.cr_name	= "AUTH_GSS",
1509	.crdestroy	= gss_destroy_cred,
1510	.cr_init	= gss_cred_init,
1511	.crbind		= rpcauth_generic_bind_cred,
1512	.crmatch	= gss_match,
1513	.crmarshal	= gss_marshal,
1514	.crrefresh	= gss_refresh,
1515	.crvalidate	= gss_validate,
1516	.crwrap_req	= gss_wrap_req,
1517	.crunwrap_resp	= gss_unwrap_resp,
1518};
1519
1520static const struct rpc_credops gss_nullops = {
1521	.cr_name	= "AUTH_GSS",
1522	.crdestroy	= gss_destroy_nullcred,
1523	.crbind		= rpcauth_generic_bind_cred,
1524	.crmatch	= gss_match,
1525	.crmarshal	= gss_marshal,
1526	.crrefresh	= gss_refresh_null,
1527	.crvalidate	= gss_validate,
1528	.crwrap_req	= gss_wrap_req,
1529	.crunwrap_resp	= gss_unwrap_resp,
1530};
1531
1532static const struct rpc_pipe_ops gss_upcall_ops_v0 = {
1533	.upcall		= gss_pipe_upcall,
1534	.downcall	= gss_pipe_downcall,
1535	.destroy_msg	= gss_pipe_destroy_msg,
1536	.open_pipe	= gss_pipe_open_v0,
1537	.release_pipe	= gss_pipe_release,
1538};
1539
1540static const struct rpc_pipe_ops gss_upcall_ops_v1 = {
1541	.upcall		= gss_pipe_upcall,
1542	.downcall	= gss_pipe_downcall,
1543	.destroy_msg	= gss_pipe_destroy_msg,
1544	.open_pipe	= gss_pipe_open_v1,
1545	.release_pipe	= gss_pipe_release,
1546};
1547
1548/*
1549 * Initialize RPCSEC_GSS module
1550 */
1551static int __init init_rpcsec_gss(void)
1552{
1553	int err = 0;
1554
1555	err = rpcauth_register(&authgss_ops);
1556	if (err)
1557		goto out;
1558	err = gss_svc_init();
1559	if (err)
1560		goto out_unregister;
1561	rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version");
1562	return 0;
1563out_unregister:
1564	rpcauth_unregister(&authgss_ops);
1565out:
1566	return err;
1567}
1568
1569static void __exit exit_rpcsec_gss(void)
1570{
1571	gss_svc_shutdown();
1572	rpcauth_unregister(&authgss_ops);
1573	rcu_barrier(); /* Wait for completion of call_rcu()'s */
1574}
1575
1576MODULE_LICENSE("GPL");
1577module_init(init_rpcsec_gss)
1578module_exit(exit_rpcsec_gss)
1579