auth_gss.c revision 2446ab6070861aba2dd9229463ffbc40016a9f33
1/*
2 * linux/net/sunrpc/auth_gss/auth_gss.c
3 *
4 * RPCSEC_GSS client authentication.
5 *
6 *  Copyright (c) 2000 The Regents of the University of Michigan.
7 *  All rights reserved.
8 *
9 *  Dug Song       <dugsong@monkey.org>
10 *  Andy Adamson   <andros@umich.edu>
11 *
12 *  Redistribution and use in source and binary forms, with or without
13 *  modification, are permitted provided that the following conditions
14 *  are met:
15 *
16 *  1. Redistributions of source code must retain the above copyright
17 *     notice, this list of conditions and the following disclaimer.
18 *  2. Redistributions in binary form must reproduce the above copyright
19 *     notice, this list of conditions and the following disclaimer in the
20 *     documentation and/or other materials provided with the distribution.
21 *  3. Neither the name of the University nor the names of its
22 *     contributors may be used to endorse or promote products derived
23 *     from this software without specific prior written permission.
24 *
25 *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
26 *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
27 *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
29 *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30 *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31 *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
32 *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 */
37
38
39#include <linux/module.h>
40#include <linux/init.h>
41#include <linux/types.h>
42#include <linux/slab.h>
43#include <linux/sched.h>
44#include <linux/pagemap.h>
45#include <linux/sunrpc/clnt.h>
46#include <linux/sunrpc/auth.h>
47#include <linux/sunrpc/auth_gss.h>
48#include <linux/sunrpc/svcauth_gss.h>
49#include <linux/sunrpc/gss_err.h>
50#include <linux/workqueue.h>
51#include <linux/sunrpc/rpc_pipe_fs.h>
52#include <linux/sunrpc/gss_api.h>
53#include <asm/uaccess.h>
54
/* Forward declarations of the operation tables used throughout this file. */
static const struct rpc_authops authgss_ops;

/*
 * gss_credops is installed on every new cred by gss_create_cred();
 * gss_destroying_context() switches a cred to gss_nullops before it sends
 * the context-destroy NULL call.
 */
static const struct rpc_credops gss_credops;
static const struct rpc_credops gss_nullops;

/* Initial value of gss_expired_cred_retry_delay.
 * NOTE(review): presumably seconds — confirm where the variable is read. */
#define GSS_RETRY_EXPIRED 5
static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED;
62
#ifdef RPC_DEBUG
/* Debug facility for the dprintk() calls in this file. */
# define RPCDBG_FACILITY	RPCDBG_AUTH
#endif

/*
 * Space reserved for the marshalled credential; gss_create() stores it,
 * converted to 32-bit words, in auth->au_cslack.
 */
#define GSS_CRED_SLACK		(RPC_MAX_AUTH_SIZE * 2)
/* length of a krb5 verifier (48), plus data added before arguments when
 * using integrity (two 4-byte integers); gss_create() stores it, converted
 * to 32-bit words, in auth->au_rslack: */
#define GSS_VERF_SLACK		100
71
/*
 * Per-client RPCSEC_GSS authenticator. The generic RPC auth layer only sees
 * the embedded rpc_auth; container_of() recovers the gss_auth from it.
 *
 * Lifetime is governed by @kref: gss_create() initialises it,
 * gss_create_cred() takes one reference per credential, gss_destroy() and
 * gss_destroy_nullcred() drop references, and gss_free() runs on the last
 * put.
 */
struct gss_auth {
	struct kref kref;
	struct rpc_auth rpc_auth;
	struct gss_api_mech *mech;	/* mechanism looked up from the pseudoflavor */
	enum rpc_gss_svc service;	/* GSS service level derived from the pseudoflavor */
	struct rpc_clnt *client;	/* RPC client this authenticator belongs to */
	/*
	 * There are two upcall pipes; dentry[1], named "gssd", is used
	 * for the new text-based upcall; dentry[0] is named after the
	 * mechanism (for example, "krb5") and exists for
	 * backwards-compatibility with older gssd's.
	 */
	struct rpc_pipe *pipe[2];
};
86
/*
 * pipe_version >= 0 if and only if someone has a pipe open.
 *
 * More precisely, it stays >= 0 while pipe_users is non-zero. Each
 * successful gss_pipe_open() and each upcall message (via
 * get_pipe_version()) holds a pipe_users reference, and the final
 * put_pipe_version() resets pipe_version to -1. pipe_version is written
 * under pipe_version_lock.
 */
static int pipe_version = -1;
static atomic_t pipe_users = ATOMIC_INIT(0);
static DEFINE_SPINLOCK(pipe_version_lock);
/*
 * Both queues are woken by gss_pipe_open() when a version is chosen:
 * RPC tasks parked by gss_refresh_upcall() sleep on the rpc_wait_queue,
 * and threads in gss_create_upcall() sleep on the plain wait queue.
 */
static struct rpc_wait_queue pipe_version_rpc_waitqueue;
static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);

static void gss_free_ctx(struct gss_cl_ctx *);
static const struct rpc_pipe_ops gss_upcall_ops_v0;
static const struct rpc_pipe_ops gss_upcall_ops_v1;
97
98static inline struct gss_cl_ctx *
99gss_get_ctx(struct gss_cl_ctx *ctx)
100{
101	atomic_inc(&ctx->count);
102	return ctx;
103}
104
105static inline void
106gss_put_ctx(struct gss_cl_ctx *ctx)
107{
108	if (atomic_dec_and_test(&ctx->count))
109		gss_free_ctx(ctx);
110}
111
/* gss_cred_set_ctx:
 * Install @ctx as the GSS context of @cred. Called from
 * gss_handle_downcall_result and gss_create_upcall, in both cases with
 * pipe->lock held; that lock protects the exchange of an old context for a
 * new one.
 *
 * Only a cred still flagged RPCAUTH_CRED_NEW is updated; for any other
 * cred this is a no-op. On update the cred takes its own reference on @ctx,
 * becomes RPCAUTH_CRED_UPTODATE and loses RPCAUTH_CRED_NEW.
 */
static void
gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
{
	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);

	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
		return;
	gss_get_ctx(ctx);
	/* Publish the context to RCU readers such as gss_cred_get_ctx(). */
	rcu_assign_pointer(gss_cred->gc_ctx, ctx);
	set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
	/*
	 * Intended to order the stores above before NEW is cleared, so code
	 * that sees NEW clear (e.g. gss_match) also sees gc_ctx/UPTODATE.
	 */
	smp_mb__before_clear_bit();
	clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
}
130
131static const void *
132simple_get_bytes(const void *p, const void *end, void *res, size_t len)
133{
134	const void *q = (const void *)((const char *)p + len);
135	if (unlikely(q > end || q < p))
136		return ERR_PTR(-EFAULT);
137	memcpy(res, p, len);
138	return q;
139}
140
141static inline const void *
142simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest)
143{
144	const void *q;
145	unsigned int len;
146
147	p = simple_get_bytes(p, end, &len, sizeof(len));
148	if (IS_ERR(p))
149		return p;
150	q = (const void *)((const char *)p + len);
151	if (unlikely(q > end || q < p))
152		return ERR_PTR(-EFAULT);
153	dest->data = kmemdup(p, len, GFP_NOFS);
154	if (unlikely(dest->data == NULL))
155		return ERR_PTR(-ENOMEM);
156	dest->len = len;
157	return q;
158}
159
160static struct gss_cl_ctx *
161gss_cred_get_ctx(struct rpc_cred *cred)
162{
163	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
164	struct gss_cl_ctx *ctx = NULL;
165
166	rcu_read_lock();
167	if (gss_cred->gc_ctx)
168		ctx = gss_get_ctx(gss_cred->gc_ctx);
169	rcu_read_unlock();
170	return ctx;
171}
172
173static struct gss_cl_ctx *
174gss_alloc_context(void)
175{
176	struct gss_cl_ctx *ctx;
177
178	ctx = kzalloc(sizeof(*ctx), GFP_NOFS);
179	if (ctx != NULL) {
180		ctx->gc_proc = RPC_GSS_PROC_DATA;
181		ctx->gc_seq = 1;	/* NetApp 6.4R1 doesn't accept seq. no. 0 */
182		spin_lock_init(&ctx->gc_seq_lock);
183		atomic_set(&ctx->count,1);
184	}
185	return ctx;
186}
187
188#define GSSD_MIN_TIMEOUT (60 * 60)
189static const void *
190gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm)
191{
192	const void *q;
193	unsigned int seclen;
194	unsigned int timeout;
195	u32 window_size;
196	int ret;
197
198	/* First unsigned int gives the lifetime (in seconds) of the cred */
199	p = simple_get_bytes(p, end, &timeout, sizeof(timeout));
200	if (IS_ERR(p))
201		goto err;
202	if (timeout == 0)
203		timeout = GSSD_MIN_TIMEOUT;
204	ctx->gc_expiry = jiffies + (unsigned long)timeout * HZ * 3 / 4;
205	/* Sequence number window. Determines the maximum number of simultaneous requests */
206	p = simple_get_bytes(p, end, &window_size, sizeof(window_size));
207	if (IS_ERR(p))
208		goto err;
209	ctx->gc_win = window_size;
210	/* gssd signals an error by passing ctx->gc_win = 0: */
211	if (ctx->gc_win == 0) {
212		/*
213		 * in which case, p points to an error code. Anything other
214		 * than -EKEYEXPIRED gets converted to -EACCES.
215		 */
216		p = simple_get_bytes(p, end, &ret, sizeof(ret));
217		if (!IS_ERR(p))
218			p = (ret == -EKEYEXPIRED) ? ERR_PTR(-EKEYEXPIRED) :
219						    ERR_PTR(-EACCES);
220		goto err;
221	}
222	/* copy the opaque wire context */
223	p = simple_get_netobj(p, end, &ctx->gc_wire_ctx);
224	if (IS_ERR(p))
225		goto err;
226	/* import the opaque security context */
227	p  = simple_get_bytes(p, end, &seclen, sizeof(seclen));
228	if (IS_ERR(p))
229		goto err;
230	q = (const void *)((const char *)p + seclen);
231	if (unlikely(q > end || q < p)) {
232		p = ERR_PTR(-EFAULT);
233		goto err;
234	}
235	ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, GFP_NOFS);
236	if (ret < 0) {
237		p = ERR_PTR(ret);
238		goto err;
239	}
240	return q;
241err:
242	dprintk("RPC:       gss_fill_context returning %ld\n", -PTR_ERR(p));
243	return p;
244}
245
/* Size of the text buffer used to build a version 1 ("gssd") upcall. */
#define UPCALL_BUF_LEN 128

/*
 * One outstanding upcall to gssd. It is shared by every request for the
 * same uid on a pipe (see gss_add_msg) and freed by gss_release_msg() when
 * @count drops to zero.
 */
struct gss_upcall_msg {
	atomic_t count;
	uid_t	uid;			/* uid the context is requested for */
	struct rpc_pipe_msg msg;	/* pipe payload; msg.errno holds the downcall status */
	struct list_head list;		/* on pipe->in_downcall while awaiting a downcall */
	struct gss_auth *auth;
	struct rpc_pipe *pipe;		/* pipe the upcall was issued on */
	struct rpc_wait_queue rpc_waitqueue;	/* RPC tasks waiting for the result */
	wait_queue_head_t waitqueue;	/* synchronous waiters (gss_create_upcall) */
	struct gss_cl_ctx *ctx;		/* set by a successful downcall */
	char databuf[UPCALL_BUF_LEN];	/* text of a version 1 upcall */
};
260
261static int get_pipe_version(void)
262{
263	int ret;
264
265	spin_lock(&pipe_version_lock);
266	if (pipe_version >= 0) {
267		atomic_inc(&pipe_users);
268		ret = pipe_version;
269	} else
270		ret = -EAGAIN;
271	spin_unlock(&pipe_version_lock);
272	return ret;
273}
274
275static void put_pipe_version(void)
276{
277	if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) {
278		pipe_version = -1;
279		spin_unlock(&pipe_version_lock);
280	}
281}
282
283static void
284gss_release_msg(struct gss_upcall_msg *gss_msg)
285{
286	if (!atomic_dec_and_test(&gss_msg->count))
287		return;
288	put_pipe_version();
289	BUG_ON(!list_empty(&gss_msg->list));
290	if (gss_msg->ctx != NULL)
291		gss_put_ctx(gss_msg->ctx);
292	rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue);
293	kfree(gss_msg);
294}
295
296static struct gss_upcall_msg *
297__gss_find_upcall(struct rpc_pipe *pipe, uid_t uid)
298{
299	struct gss_upcall_msg *pos;
300	list_for_each_entry(pos, &pipe->in_downcall, list) {
301		if (pos->uid != uid)
302			continue;
303		atomic_inc(&pos->count);
304		dprintk("RPC:       gss_find_upcall found msg %p\n", pos);
305		return pos;
306	}
307	dprintk("RPC:       gss_find_upcall found nothing\n");
308	return NULL;
309}
310
311/* Try to add an upcall to the pipefs queue.
312 * If an upcall owned by our uid already exists, then we return a reference
313 * to that upcall instead of adding the new upcall.
314 */
315static inline struct gss_upcall_msg *
316gss_add_msg(struct gss_upcall_msg *gss_msg)
317{
318	struct rpc_pipe *pipe = gss_msg->pipe;
319	struct gss_upcall_msg *old;
320
321	spin_lock(&pipe->lock);
322	old = __gss_find_upcall(pipe, gss_msg->uid);
323	if (old == NULL) {
324		atomic_inc(&gss_msg->count);
325		list_add(&gss_msg->list, &pipe->in_downcall);
326	} else
327		gss_msg = old;
328	spin_unlock(&pipe->lock);
329	return gss_msg;
330}
331
/*
 * Take @gss_msg off the in_downcall list and wake everyone waiting on it.
 * Queued RPC tasks receive msg.errno as their status. Then drop the
 * reference the list held.
 *
 * Caller holds pipe->lock and its own reference to @gss_msg, which is why
 * a bare atomic_dec() (rather than gss_release_msg()) is sufficient here.
 */
static void
__gss_unhash_msg(struct gss_upcall_msg *gss_msg)
{
	list_del_init(&gss_msg->list);
	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
	wake_up_all(&gss_msg->waitqueue);
	atomic_dec(&gss_msg->count);
}
340
341static void
342gss_unhash_msg(struct gss_upcall_msg *gss_msg)
343{
344	struct rpc_pipe *pipe = gss_msg->pipe;
345
346	if (list_empty(&gss_msg->list))
347		return;
348	spin_lock(&pipe->lock);
349	if (!list_empty(&gss_msg->list))
350		__gss_unhash_msg(gss_msg);
351	spin_unlock(&pipe->lock);
352}
353
354static void
355gss_handle_downcall_result(struct gss_cred *gss_cred, struct gss_upcall_msg *gss_msg)
356{
357	switch (gss_msg->msg.errno) {
358	case 0:
359		if (gss_msg->ctx == NULL)
360			break;
361		clear_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
362		gss_cred_set_ctx(&gss_cred->gc_base, gss_msg->ctx);
363		break;
364	case -EKEYEXPIRED:
365		set_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
366	}
367	gss_cred->gc_upcall_timestamp = jiffies;
368	gss_cred->gc_upcall = NULL;
369	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
370}
371
372static void
373gss_upcall_callback(struct rpc_task *task)
374{
375	struct gss_cred *gss_cred = container_of(task->tk_rqstp->rq_cred,
376			struct gss_cred, gc_base);
377	struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall;
378	struct rpc_pipe *pipe = gss_msg->pipe;
379
380	spin_lock(&pipe->lock);
381	gss_handle_downcall_result(gss_cred, gss_msg);
382	spin_unlock(&pipe->lock);
383	task->tk_status = gss_msg->msg.errno;
384	gss_release_msg(gss_msg);
385}
386
387static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
388{
389	gss_msg->msg.data = &gss_msg->uid;
390	gss_msg->msg.len = sizeof(gss_msg->uid);
391}
392
393static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
394				struct rpc_clnt *clnt,
395				const char *service_name)
396{
397	struct gss_api_mech *mech = gss_msg->auth->mech;
398	char *p = gss_msg->databuf;
399	int len = 0;
400
401	gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ",
402				   mech->gm_name,
403				   gss_msg->uid);
404	p += gss_msg->msg.len;
405	if (clnt->cl_principal) {
406		len = sprintf(p, "target=%s ", clnt->cl_principal);
407		p += len;
408		gss_msg->msg.len += len;
409	}
410	if (service_name != NULL) {
411		len = sprintf(p, "service=%s ", service_name);
412		p += len;
413		gss_msg->msg.len += len;
414	}
415	if (mech->gm_upcall_enctypes) {
416		len = sprintf(p, "enctypes=%s ", mech->gm_upcall_enctypes);
417		p += len;
418		gss_msg->msg.len += len;
419	}
420	len = sprintf(p, "\n");
421	gss_msg->msg.len += len;
422
423	gss_msg->msg.data = gss_msg->databuf;
424	BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN);
425}
426
427static void gss_encode_msg(struct gss_upcall_msg *gss_msg,
428				struct rpc_clnt *clnt,
429				const char *service_name)
430{
431	if (pipe_version == 0)
432		gss_encode_v0_msg(gss_msg);
433	else /* pipe_version == 1 */
434		gss_encode_v1_msg(gss_msg, clnt, service_name);
435}
436
437static struct gss_upcall_msg *
438gss_alloc_msg(struct gss_auth *gss_auth, struct rpc_clnt *clnt,
439		uid_t uid, const char *service_name)
440{
441	struct gss_upcall_msg *gss_msg;
442	int vers;
443
444	gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
445	if (gss_msg == NULL)
446		return ERR_PTR(-ENOMEM);
447	vers = get_pipe_version();
448	if (vers < 0) {
449		kfree(gss_msg);
450		return ERR_PTR(vers);
451	}
452	gss_msg->pipe = gss_auth->pipe[vers];
453	INIT_LIST_HEAD(&gss_msg->list);
454	rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
455	init_waitqueue_head(&gss_msg->waitqueue);
456	atomic_set(&gss_msg->count, 1);
457	gss_msg->uid = uid;
458	gss_msg->auth = gss_auth;
459	gss_encode_msg(gss_msg, clnt, service_name);
460	return gss_msg;
461}
462
463static struct gss_upcall_msg *
464gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
465{
466	struct gss_cred *gss_cred = container_of(cred,
467			struct gss_cred, gc_base);
468	struct gss_upcall_msg *gss_new, *gss_msg;
469	uid_t uid = cred->cr_uid;
470
471	gss_new = gss_alloc_msg(gss_auth, clnt, uid, gss_cred->gc_principal);
472	if (IS_ERR(gss_new))
473		return gss_new;
474	gss_msg = gss_add_msg(gss_new);
475	if (gss_msg == gss_new) {
476		int res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg);
477		if (res) {
478			gss_unhash_msg(gss_new);
479			gss_msg = ERR_PTR(res);
480		}
481	} else
482		gss_release_msg(gss_new);
483	return gss_msg;
484}
485
486static void warn_gssd(void)
487{
488	static unsigned long ratelimit;
489	unsigned long now = jiffies;
490
491	if (time_after(now, ratelimit)) {
492		printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
493				"Please check user daemon is running.\n");
494		ratelimit = now + 15*HZ;
495	}
496}
497
/*
 * Asynchronous credential refresh for an RPC task: start (or join) the
 * upcall for the task's cred.
 *
 * Returns:
 *  - -EAGAIN when no gss pipe is open. The task is then parked on
 *    pipe_version_rpc_waitqueue for up to 15s.
 *  - a negative errno if the upcall could not be set up.
 *  - 0 when the task was put to sleep waiting for an upcall.
 *  - otherwise the already-available downcall status (msg.errno).
 */
static inline int
gss_refresh_upcall(struct rpc_task *task)
{
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
	struct gss_auth *gss_auth = container_of(cred->cr_auth,
			struct gss_auth, rpc_auth);
	struct gss_cred *gss_cred = container_of(cred,
			struct gss_cred, gc_base);
	struct gss_upcall_msg *gss_msg;
	struct rpc_pipe *pipe;
	int err = 0;

	dprintk("RPC: %5u gss_refresh_upcall for uid %u\n", task->tk_pid,
								cred->cr_uid);
	gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
	if (PTR_ERR(gss_msg) == -EAGAIN) {
		/* XXX: warning on the first, under the assumption we
		 * shouldn't normally hit this case on a refresh. */
		warn_gssd();
		/* Sleep until gss_pipe_open() wakes this queue or 15s pass. */
		task->tk_timeout = 15*HZ;
		rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL);
		return -EAGAIN;
	}
	if (IS_ERR(gss_msg)) {
		err = PTR_ERR(gss_msg);
		goto out;
	}
	pipe = gss_msg->pipe;
	spin_lock(&pipe->lock);
	if (gss_cred->gc_upcall != NULL)
		/* This cred already has an upcall in flight: wait on it. */
		rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
	else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
		/* No answer yet: attach the upcall to the cred and wait for
		 * the downcall with no timeout. */
		task->tk_timeout = 0;
		gss_cred->gc_upcall = gss_msg;
		/* gss_upcall_callback will release the reference to gss_upcall_msg */
		atomic_inc(&gss_msg->count);
		rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback);
	} else {
		/* The downcall has already completed: apply its result now. */
		gss_handle_downcall_result(gss_cred, gss_msg);
		err = gss_msg->msg.errno;
	}
	spin_unlock(&pipe->lock);
	/* Drop the reference returned by gss_setup_upcall(). */
	gss_release_msg(gss_msg);
out:
	dprintk("RPC: %5u gss_refresh_upcall for uid %u result %d\n",
			task->tk_pid, cred->cr_uid, err);
	return err;
}
546
547static inline int
548gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
549{
550	struct rpc_pipe *pipe;
551	struct rpc_cred *cred = &gss_cred->gc_base;
552	struct gss_upcall_msg *gss_msg;
553	DEFINE_WAIT(wait);
554	int err = 0;
555
556	dprintk("RPC:       gss_upcall for uid %u\n", cred->cr_uid);
557retry:
558	gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
559	if (PTR_ERR(gss_msg) == -EAGAIN) {
560		err = wait_event_interruptible_timeout(pipe_version_waitqueue,
561				pipe_version >= 0, 15*HZ);
562		if (pipe_version < 0) {
563			warn_gssd();
564			err = -EACCES;
565		}
566		if (err)
567			goto out;
568		goto retry;
569	}
570	if (IS_ERR(gss_msg)) {
571		err = PTR_ERR(gss_msg);
572		goto out;
573	}
574	pipe = gss_msg->pipe;
575	for (;;) {
576		prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_KILLABLE);
577		spin_lock(&pipe->lock);
578		if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
579			break;
580		}
581		spin_unlock(&pipe->lock);
582		if (fatal_signal_pending(current)) {
583			err = -ERESTARTSYS;
584			goto out_intr;
585		}
586		schedule();
587	}
588	if (gss_msg->ctx)
589		gss_cred_set_ctx(cred, gss_msg->ctx);
590	else
591		err = gss_msg->msg.errno;
592	spin_unlock(&pipe->lock);
593out_intr:
594	finish_wait(&gss_msg->waitqueue, &wait);
595	gss_release_msg(gss_msg);
596out:
597	dprintk("RPC:       gss_create_upcall for uid %u result %d\n",
598			cred->cr_uid, err);
599	return err;
600}
601
602#define MSG_BUF_MAXSIZE 1024
603
604static ssize_t
605gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
606{
607	const void *p, *end;
608	void *buf;
609	struct gss_upcall_msg *gss_msg;
610	struct rpc_pipe *pipe = RPC_I(filp->f_dentry->d_inode)->pipe;
611	struct gss_cl_ctx *ctx;
612	uid_t uid;
613	ssize_t err = -EFBIG;
614
615	if (mlen > MSG_BUF_MAXSIZE)
616		goto out;
617	err = -ENOMEM;
618	buf = kmalloc(mlen, GFP_NOFS);
619	if (!buf)
620		goto out;
621
622	err = -EFAULT;
623	if (copy_from_user(buf, src, mlen))
624		goto err;
625
626	end = (const void *)((char *)buf + mlen);
627	p = simple_get_bytes(buf, end, &uid, sizeof(uid));
628	if (IS_ERR(p)) {
629		err = PTR_ERR(p);
630		goto err;
631	}
632
633	err = -ENOMEM;
634	ctx = gss_alloc_context();
635	if (ctx == NULL)
636		goto err;
637
638	err = -ENOENT;
639	/* Find a matching upcall */
640	spin_lock(&pipe->lock);
641	gss_msg = __gss_find_upcall(pipe, uid);
642	if (gss_msg == NULL) {
643		spin_unlock(&pipe->lock);
644		goto err_put_ctx;
645	}
646	list_del_init(&gss_msg->list);
647	spin_unlock(&pipe->lock);
648
649	p = gss_fill_context(p, end, ctx, gss_msg->auth->mech);
650	if (IS_ERR(p)) {
651		err = PTR_ERR(p);
652		switch (err) {
653		case -EACCES:
654		case -EKEYEXPIRED:
655			gss_msg->msg.errno = err;
656			err = mlen;
657			break;
658		case -EFAULT:
659		case -ENOMEM:
660		case -EINVAL:
661		case -ENOSYS:
662			gss_msg->msg.errno = -EAGAIN;
663			break;
664		default:
665			printk(KERN_CRIT "%s: bad return from "
666				"gss_fill_context: %zd\n", __func__, err);
667			BUG();
668		}
669		goto err_release_msg;
670	}
671	gss_msg->ctx = gss_get_ctx(ctx);
672	err = mlen;
673
674err_release_msg:
675	spin_lock(&pipe->lock);
676	__gss_unhash_msg(gss_msg);
677	spin_unlock(&pipe->lock);
678	gss_release_msg(gss_msg);
679err_put_ctx:
680	gss_put_ctx(ctx);
681err:
682	kfree(buf);
683out:
684	dprintk("RPC:       gss_pipe_downcall returning %Zd\n", err);
685	return err;
686}
687
688static int gss_pipe_open(struct inode *inode, int new_version)
689{
690	int ret = 0;
691
692	spin_lock(&pipe_version_lock);
693	if (pipe_version < 0) {
694		/* First open of any gss pipe determines the version: */
695		pipe_version = new_version;
696		rpc_wake_up(&pipe_version_rpc_waitqueue);
697		wake_up(&pipe_version_waitqueue);
698	} else if (pipe_version != new_version) {
699		/* Trying to open a pipe of a different version */
700		ret = -EBUSY;
701		goto out;
702	}
703	atomic_inc(&pipe_users);
704out:
705	spin_unlock(&pipe_version_lock);
706	return ret;
707
708}
709
/* Open handler selecting pipe version 0 (presumably the legacy,
 * mechanism-named pipe served by gss_upcall_ops_v0). */
static int gss_pipe_open_v0(struct inode *inode)
{
	int ret = gss_pipe_open(inode, 0);

	return ret;
}

/* Open handler selecting pipe version 1 (presumably the text-based "gssd"
 * pipe served by gss_upcall_ops_v1). */
static int gss_pipe_open_v1(struct inode *inode)
{
	int ret = gss_pipe_open(inode, 1);

	return ret;
}
719
/*
 * Release hook for a gss upcall pipe (NOTE(review): presumably invoked when
 * gssd closes the pipe — confirm in gss_upcall_ops_v0/v1 and rpc_pipefs).
 *
 * Handles every upcall on this pipe that is still awaiting a downcall but
 * whose message is no longer queued for reading (msg.list empty): it is
 * failed with -EPIPE and unhashed, which wakes its waiters. Finally, the
 * pipe_users reference taken by gss_pipe_open() is dropped.
 */
static void
gss_pipe_release(struct inode *inode)
{
	struct rpc_pipe *pipe = RPC_I(inode)->pipe;
	struct gss_upcall_msg *gss_msg;

restart:
	spin_lock(&pipe->lock);
	list_for_each_entry(gss_msg, &pipe->in_downcall, list) {

		/* Still queued on the pipe itself: not handled here. */
		if (!list_empty(&gss_msg->msg.list))
			continue;
		gss_msg->msg.errno = -EPIPE;
		/* Keep the message alive across the unlock below. */
		atomic_inc(&gss_msg->count);
		__gss_unhash_msg(gss_msg);
		spin_unlock(&pipe->lock);
		gss_release_msg(gss_msg);
		/* The list may have changed while unlocked: rescan it. */
		goto restart;
	}
	spin_unlock(&pipe->lock);

	put_pipe_version();
}
743
744static void
745gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
746{
747	struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
748
749	if (msg->errno < 0) {
750		dprintk("RPC:       gss_pipe_destroy_msg releasing msg %p\n",
751				gss_msg);
752		atomic_inc(&gss_msg->count);
753		gss_unhash_msg(gss_msg);
754		if (msg->errno == -ETIMEDOUT)
755			warn_gssd();
756		gss_release_msg(gss_msg);
757	}
758}
759
760static void gss_pipes_dentries_destroy(struct rpc_auth *auth)
761{
762	struct gss_auth *gss_auth;
763
764	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
765	if (gss_auth->pipe[0]->dentry)
766		rpc_unlink(gss_auth->pipe[0]->dentry);
767	if (gss_auth->pipe[1]->dentry)
768		rpc_unlink(gss_auth->pipe[1]->dentry);
769}
770
771static int gss_pipes_dentries_create(struct rpc_auth *auth)
772{
773	int err;
774	struct gss_auth *gss_auth;
775	struct rpc_clnt *clnt;
776
777	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
778	clnt = gss_auth->client;
779
780	gss_auth->pipe[1]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry,
781						      "gssd",
782						      clnt, gss_auth->pipe[1]);
783	if (IS_ERR(gss_auth->pipe[1]->dentry))
784		return PTR_ERR(gss_auth->pipe[1]->dentry);
785	gss_auth->pipe[0]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry,
786						      gss_auth->mech->gm_name,
787						      clnt, gss_auth->pipe[0]);
788	if (IS_ERR(gss_auth->pipe[0]->dentry)) {
789		err = PTR_ERR(gss_auth->pipe[0]->dentry);
790		goto err_unlink_pipe_1;
791	}
792	return 0;
793
794err_unlink_pipe_1:
795	rpc_unlink(gss_auth->pipe[1]->dentry);
796	return err;
797}
798
799static void gss_pipes_dentries_destroy_net(struct rpc_clnt *clnt,
800					   struct rpc_auth *auth)
801{
802	struct net *net = rpc_net_ns(clnt);
803	struct super_block *sb;
804
805	sb = rpc_get_sb_net(net);
806	if (sb) {
807		if (clnt->cl_dentry)
808			gss_pipes_dentries_destroy(auth);
809		rpc_put_sb_net(net);
810	}
811}
812
813static int gss_pipes_dentries_create_net(struct rpc_clnt *clnt,
814					 struct rpc_auth *auth)
815{
816	struct net *net = rpc_net_ns(clnt);
817	struct super_block *sb;
818	int err = 0;
819
820	sb = rpc_get_sb_net(net);
821	if (sb) {
822		if (clnt->cl_dentry)
823			err = gss_pipes_dentries_create(auth);
824		rpc_put_sb_net(net);
825	}
826	return err;
827}
828
/*
 * Create an RPCSEC_GSS authenticator for @clnt.
 *
 * NOTE: we have the opportunity to use different
 * parameters based on the input flavor (which must be a pseudoflavor)
 *
 * Steps:
 *  - look up the GSS mechanism and service level for @flavor;
 *  - allocate the v1 ("gssd") and v0 (mechanism-named) upcall pipes;
 *  - create their rpc_pipefs dentries where possible;
 *  - initialise the credential cache.
 *
 * A module reference is held for the authenticator's lifetime; gss_free()
 * drops it. Returns the embedded rpc_auth, or an ERR_PTR: -ENOMEM,
 * -EINVAL for an unknown flavor/service, or the pipe/credcache error.
 */
static struct rpc_auth *
gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
{
	struct gss_auth *gss_auth;
	struct rpc_auth * auth;
	int err = -ENOMEM; /* XXX? */

	dprintk("RPC:       creating GSS authenticator for client %p\n", clnt);

	if (!try_module_get(THIS_MODULE))
		return ERR_PTR(err);
	if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
		goto out_dec;
	gss_auth->client = clnt;
	err = -EINVAL;
	gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
	if (!gss_auth->mech) {
		printk(KERN_WARNING "%s: Pseudoflavor %d not found!\n",
				__func__, flavor);
		goto err_free;
	}
	gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor);
	if (gss_auth->service == 0)
		goto err_put_mech;
	auth = &gss_auth->rpc_auth;
	/* Slack sizes are kept in 32-bit XDR words. */
	auth->au_cslack = GSS_CRED_SLACK >> 2;
	auth->au_rslack = GSS_VERF_SLACK >> 2;
	auth->au_ops = &authgss_ops;
	auth->au_flavor = flavor;
	atomic_set(&auth->au_count, 1);
	kref_init(&gss_auth->kref);

	/*
	 * Note: if we created the old pipe first, then someone who
	 * examined the directory at the right moment might conclude
	 * that we supported only the old pipe.  So we instead create
	 * the new pipe first.
	 */
	gss_auth->pipe[1] = rpc_mkpipe_data(&gss_upcall_ops_v1,
					    RPC_PIPE_WAIT_FOR_OPEN);
	if (IS_ERR(gss_auth->pipe[1])) {
		err = PTR_ERR(gss_auth->pipe[1]);
		goto err_put_mech;
	}

	gss_auth->pipe[0] = rpc_mkpipe_data(&gss_upcall_ops_v0,
					    RPC_PIPE_WAIT_FOR_OPEN);
	if (IS_ERR(gss_auth->pipe[0])) {
		err = PTR_ERR(gss_auth->pipe[0]);
		goto err_destroy_pipe_1;
	}
	err = gss_pipes_dentries_create_net(clnt, auth);
	if (err)
		goto err_destroy_pipe_0;
	err = rpcauth_init_credcache(auth);
	if (err)
		goto err_unlink_pipes;

	return auth;
	/* Error unwinding: each label undoes one setup step, in reverse order. */
err_unlink_pipes:
	gss_pipes_dentries_destroy_net(clnt, auth);
err_destroy_pipe_0:
	rpc_destroy_pipe_data(gss_auth->pipe[0]);
err_destroy_pipe_1:
	rpc_destroy_pipe_data(gss_auth->pipe[1]);
err_put_mech:
	gss_mech_put(gss_auth->mech);
err_free:
	kfree(gss_auth);
out_dec:
	module_put(THIS_MODULE);
	return ERR_PTR(err);
}
906
907static void
908gss_free(struct gss_auth *gss_auth)
909{
910	gss_pipes_dentries_destroy_net(gss_auth->client, &gss_auth->rpc_auth);
911	rpc_destroy_pipe_data(gss_auth->pipe[0]);
912	rpc_destroy_pipe_data(gss_auth->pipe[1]);
913	gss_mech_put(gss_auth->mech);
914
915	kfree(gss_auth);
916	module_put(THIS_MODULE);
917}
918
919static void
920gss_free_callback(struct kref *kref)
921{
922	struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref);
923
924	gss_free(gss_auth);
925}
926
927static void
928gss_destroy(struct rpc_auth *auth)
929{
930	struct gss_auth *gss_auth;
931
932	dprintk("RPC:       destroying GSS authenticator %p flavor %d\n",
933			auth, auth->au_flavor);
934
935	rpcauth_destroy_credcache(auth);
936
937	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
938	kref_put(&gss_auth->kref, gss_free_callback);
939}
940
/*
 * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call
 * to the server with the GSS control procedure field set to
 * RPC_GSS_PROC_DESTROY. This should normally cause the server to release
 * all RPCSEC_GSS state associated with that context.
 *
 * Returns 0 when there is nothing to destroy on the server, i.e. the cred
 * has no context or is not RPCAUTH_CRED_UPTODATE. gss_destroy_cred() then
 * frees the cred directly. Otherwise it starts the NULL call
 * asynchronously and returns 1.
 */
static int
gss_destroying_context(struct rpc_cred *cred)
{
	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
	struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
	struct rpc_task *task;

	if (gss_cred->gc_ctx == NULL ||
	    test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
		return 0;

	/* gss_marshal() encodes ctx->gc_proc, so the call carries DESTROY. */
	gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY;
	/* NOTE(review): switching to gss_nullops presumably routes the final
	 * put of this cred to gss_destroy_nullcred() instead of back here —
	 * confirm against the gss_nullops definition. */
	cred->cr_ops = &gss_nullops;

	/* Take a reference to ensure the cred will be destroyed either
	 * by the RPC call or by the put_rpccred() below */
	get_rpccred(cred);

	task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT);
	if (!IS_ERR(task))
		rpc_put_task(task);

	put_rpccred(cred);
	return 1;
}
972
973/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure
974 * to create a new cred or context, so they check that things have been
975 * allocated before freeing them. */
976static void
977gss_do_free_ctx(struct gss_cl_ctx *ctx)
978{
979	dprintk("RPC:       gss_free_ctx\n");
980
981	gss_delete_sec_context(&ctx->gc_gss_ctx);
982	kfree(ctx->gc_wire_ctx.data);
983	kfree(ctx);
984}
985
986static void
987gss_free_ctx_callback(struct rcu_head *head)
988{
989	struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu);
990	gss_do_free_ctx(ctx);
991}
992
993static void
994gss_free_ctx(struct gss_cl_ctx *ctx)
995{
996	call_rcu(&ctx->gc_rcu, gss_free_ctx_callback);
997}
998
/* Release the memory of a gss_cred (its context is handled elsewhere). */
static void
gss_free_cred(struct gss_cred *gss_cred)
{
	void *mem = gss_cred;

	dprintk("RPC:       gss_free_cred %p\n", gss_cred);
	kfree(mem);
}
1005
1006static void
1007gss_free_cred_callback(struct rcu_head *head)
1008{
1009	struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu);
1010	gss_free_cred(gss_cred);
1011}
1012
/*
 * Final teardown of a gss_cred, in this order:
 *  1. unpublish its context;
 *  2. schedule the cred itself to be freed after an RCU grace period;
 *  3. drop the cred's context reference;
 *  4. drop the gss_auth reference taken in gss_create_cred().
 */
static void
gss_destroy_nullcred(struct rpc_cred *cred)
{
	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
	struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
	struct gss_cl_ctx *ctx = gss_cred->gc_ctx;

	/* Clear the pointer first so the context is no longer reachable
	 * through the cred before its reference is dropped below. */
	RCU_INIT_POINTER(gss_cred->gc_ctx, NULL);
	call_rcu(&cred->cr_rcu, gss_free_cred_callback);
	if (ctx)
		gss_put_ctx(ctx);
	/* May free the gss_auth if this was the last reference. */
	kref_put(&gss_auth->kref, gss_free_callback);
}
1026
/*
 * Destroy a cred. If it holds an up-to-date context, the server is first
 * told to discard it (gss_destroying_context), and teardown completes
 * later. Otherwise the cred is freed right away.
 */
static void
gss_destroy_cred(struct rpc_cred *cred)
{
	if (!gss_destroying_context(cred))
		gss_destroy_nullcred(cred);
}
1035
/*
 * Lookup RPCSEC_GSS cred for the current process.
 * This simply delegates to the generic RPC credential cache.
 */
static struct rpc_cred *
gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
{
	struct rpc_cred *cred = rpcauth_lookup_credcache(auth, acred, flags);

	return cred;
}
1044
1045static struct rpc_cred *
1046gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
1047{
1048	struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1049	struct gss_cred	*cred = NULL;
1050	int err = -ENOMEM;
1051
1052	dprintk("RPC:       gss_create_cred for uid %d, flavor %d\n",
1053		acred->uid, auth->au_flavor);
1054
1055	if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS)))
1056		goto out_err;
1057
1058	rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops);
1059	/*
1060	 * Note: in order to force a call to call_refresh(), we deliberately
1061	 * fail to flag the credential as RPCAUTH_CRED_UPTODATE.
1062	 */
1063	cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW;
1064	cred->gc_service = gss_auth->service;
1065	cred->gc_principal = NULL;
1066	if (acred->machine_cred)
1067		cred->gc_principal = acred->principal;
1068	kref_get(&gss_auth->kref);
1069	return &cred->gc_base;
1070
1071out_err:
1072	dprintk("RPC:       gss_create_cred failed with error %d\n", err);
1073	return ERR_PTR(err);
1074}
1075
1076static int
1077gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred)
1078{
1079	struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1080	struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base);
1081	int err;
1082
1083	do {
1084		err = gss_create_upcall(gss_auth, gss_cred);
1085	} while (err == -EAGAIN);
1086	return err;
1087}
1088
1089static int
1090gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags)
1091{
1092	struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
1093
1094	if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
1095		goto out;
1096	/* Don't match with creds that have expired. */
1097	if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
1098		return 0;
1099	if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags))
1100		return 0;
1101out:
1102	if (acred->principal != NULL) {
1103		if (gss_cred->gc_principal == NULL)
1104			return 0;
1105		return strcmp(acred->principal, gss_cred->gc_principal) == 0;
1106	}
1107	if (gss_cred->gc_principal != NULL)
1108		return 0;
1109	return rc->cr_uid == acred->uid;
1110}
1111
1112/*
1113* Marshal credentials.
1114* Maybe we should keep a cached credential for performance reasons.
1115*/
1116static __be32 *
1117gss_marshal(struct rpc_task *task, __be32 *p)
1118{
1119	struct rpc_rqst *req = task->tk_rqstp;
1120	struct rpc_cred *cred = req->rq_cred;
1121	struct gss_cred	*gss_cred = container_of(cred, struct gss_cred,
1122						 gc_base);
1123	struct gss_cl_ctx	*ctx = gss_cred_get_ctx(cred);
1124	__be32		*cred_len;
1125	u32             maj_stat = 0;
1126	struct xdr_netobj mic;
1127	struct kvec	iov;
1128	struct xdr_buf	verf_buf;
1129
1130	dprintk("RPC: %5u gss_marshal\n", task->tk_pid);
1131
1132	*p++ = htonl(RPC_AUTH_GSS);
1133	cred_len = p++;
1134
1135	spin_lock(&ctx->gc_seq_lock);
1136	req->rq_seqno = ctx->gc_seq++;
1137	spin_unlock(&ctx->gc_seq_lock);
1138
1139	*p++ = htonl((u32) RPC_GSS_VERSION);
1140	*p++ = htonl((u32) ctx->gc_proc);
1141	*p++ = htonl((u32) req->rq_seqno);
1142	*p++ = htonl((u32) gss_cred->gc_service);
1143	p = xdr_encode_netobj(p, &ctx->gc_wire_ctx);
1144	*cred_len = htonl((p - (cred_len + 1)) << 2);
1145
1146	/* We compute the checksum for the verifier over the xdr-encoded bytes
1147	 * starting with the xid and ending at the end of the credential: */
1148	iov.iov_base = xprt_skip_transport_header(task->tk_xprt,
1149					req->rq_snd_buf.head[0].iov_base);
1150	iov.iov_len = (u8 *)p - (u8 *)iov.iov_base;
1151	xdr_buf_from_iov(&iov, &verf_buf);
1152
1153	/* set verifier flavor*/
1154	*p++ = htonl(RPC_AUTH_GSS);
1155
1156	mic.data = (u8 *)(p + 1);
1157	maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1158	if (maj_stat == GSS_S_CONTEXT_EXPIRED) {
1159		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1160	} else if (maj_stat != 0) {
1161		printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat);
1162		goto out_put_ctx;
1163	}
1164	p = xdr_encode_opaque(p, NULL, mic.len);
1165	gss_put_ctx(ctx);
1166	return p;
1167out_put_ctx:
1168	gss_put_ctx(ctx);
1169	return NULL;
1170}
1171
1172static int gss_renew_cred(struct rpc_task *task)
1173{
1174	struct rpc_cred *oldcred = task->tk_rqstp->rq_cred;
1175	struct gss_cred *gss_cred = container_of(oldcred,
1176						 struct gss_cred,
1177						 gc_base);
1178	struct rpc_auth *auth = oldcred->cr_auth;
1179	struct auth_cred acred = {
1180		.uid = oldcred->cr_uid,
1181		.principal = gss_cred->gc_principal,
1182		.machine_cred = (gss_cred->gc_principal != NULL ? 1 : 0),
1183	};
1184	struct rpc_cred *new;
1185
1186	new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW);
1187	if (IS_ERR(new))
1188		return PTR_ERR(new);
1189	task->tk_rqstp->rq_cred = new;
1190	put_rpccred(oldcred);
1191	return 0;
1192}
1193
1194static int gss_cred_is_negative_entry(struct rpc_cred *cred)
1195{
1196	if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) {
1197		unsigned long now = jiffies;
1198		unsigned long begin, expire;
1199		struct gss_cred *gss_cred;
1200
1201		gss_cred = container_of(cred, struct gss_cred, gc_base);
1202		begin = gss_cred->gc_upcall_timestamp;
1203		expire = begin + gss_expired_cred_retry_delay * HZ;
1204
1205		if (time_in_range_open(now, begin, expire))
1206			return 1;
1207	}
1208	return 0;
1209}
1210
1211/*
1212* Refresh credentials. XXX - finish
1213*/
1214static int
1215gss_refresh(struct rpc_task *task)
1216{
1217	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1218	int ret = 0;
1219
1220	if (gss_cred_is_negative_entry(cred))
1221		return -EKEYEXPIRED;
1222
1223	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
1224			!test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) {
1225		ret = gss_renew_cred(task);
1226		if (ret < 0)
1227			goto out;
1228		cred = task->tk_rqstp->rq_cred;
1229	}
1230
1231	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
1232		ret = gss_refresh_upcall(task);
1233out:
1234	return ret;
1235}
1236
1237/* Dummy refresh routine: used only when destroying the context */
1238static int
1239gss_refresh_null(struct rpc_task *task)
1240{
1241	return -EACCES;
1242}
1243
1244static __be32 *
1245gss_validate(struct rpc_task *task, __be32 *p)
1246{
1247	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1248	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1249	__be32		seq;
1250	struct kvec	iov;
1251	struct xdr_buf	verf_buf;
1252	struct xdr_netobj mic;
1253	u32		flav,len;
1254	u32		maj_stat;
1255
1256	dprintk("RPC: %5u gss_validate\n", task->tk_pid);
1257
1258	flav = ntohl(*p++);
1259	if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE)
1260		goto out_bad;
1261	if (flav != RPC_AUTH_GSS)
1262		goto out_bad;
1263	seq = htonl(task->tk_rqstp->rq_seqno);
1264	iov.iov_base = &seq;
1265	iov.iov_len = sizeof(seq);
1266	xdr_buf_from_iov(&iov, &verf_buf);
1267	mic.data = (u8 *)p;
1268	mic.len = len;
1269
1270	maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1271	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1272		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1273	if (maj_stat) {
1274		dprintk("RPC: %5u gss_validate: gss_verify_mic returned "
1275				"error 0x%08x\n", task->tk_pid, maj_stat);
1276		goto out_bad;
1277	}
1278	/* We leave it to unwrap to calculate au_rslack. For now we just
1279	 * calculate the length of the verifier: */
1280	cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2;
1281	gss_put_ctx(ctx);
1282	dprintk("RPC: %5u gss_validate: gss_verify_mic succeeded.\n",
1283			task->tk_pid);
1284	return p + XDR_QUADLEN(len);
1285out_bad:
1286	gss_put_ctx(ctx);
1287	dprintk("RPC: %5u gss_validate failed.\n", task->tk_pid);
1288	return NULL;
1289}
1290
1291static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
1292				__be32 *p, void *obj)
1293{
1294	struct xdr_stream xdr;
1295
1296	xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p);
1297	encode(rqstp, &xdr, obj);
1298}
1299
/*
 * Wrap a request for RPC_GSS_SVC_INTEGRITY.
 *
 * Writes, from @p: a length word (back-filled below), this request's
 * sequence number, then the encoded arguments. It then computes a MIC over
 * the sequence number plus arguments (everything from @offset to the end of
 * snd_buf) and appends it to the end of the send buffer. The MIC goes into
 * the tail when pages or a tail are present, otherwise into the head.
 *
 * Returns 0 on success, including GSS_S_CONTEXT_EXPIRED, which only clears
 * RPCAUTH_CRED_UPTODATE. Returns -EIO if the sub-buffer cannot be built or
 * gss_get_mic() fails.
 */
static inline int
gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
		   kxdreproc_t encode, struct rpc_rqst *rqstp,
		   __be32 *p, void *obj)
{
	struct xdr_buf	*snd_buf = &rqstp->rq_snd_buf;
	struct xdr_buf	integ_buf;
	__be32          *integ_len = NULL;
	struct xdr_netobj mic;
	u32		offset;
	__be32		*q;
	struct kvec	*iov;
	u32             maj_stat = 0;
	int		status = -EIO;

	/* Reserve the databody length word; filled in once integ_buf is known. */
	integ_len = p++;
	/* Byte offset (from the head) of the integrity-protected region. */
	offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
	*p++ = htonl(rqstp->rq_seqno);

	gss_wrap_req_encode(encode, rqstp, p, obj);

	if (xdr_buf_subsegment(snd_buf, &integ_buf,
				offset, snd_buf->len - offset))
		return status;
	*integ_len = htonl(integ_buf.len);

	/* guess whether we're in the head or the tail: */
	if (snd_buf->page_len || snd_buf->tail[0].iov_len)
		iov = snd_buf->tail;
	else
		iov = snd_buf->head;
	/* Append point; the MIC body goes one word later, after its length. */
	p = iov->iov_base + iov->iov_len;
	mic.data = (u8 *)(p + 1);

	maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
	status = -EIO; /* XXX? - error returned if gss_get_mic() fails below */
	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
	else if (maj_stat)
		return status;
	q = xdr_encode_opaque(p, NULL, mic.len);

	/* Grow the chosen iov and the total length by the bytes just appended. */
	offset = (u8 *)q - (u8 *)p;
	iov->iov_len += offset;
	snd_buf->len += offset;
	return 0;
}
1347
1348static void
1349priv_release_snd_buf(struct rpc_rqst *rqstp)
1350{
1351	int i;
1352
1353	for (i=0; i < rqstp->rq_enc_pages_num; i++)
1354		__free_page(rqstp->rq_enc_pages[i]);
1355	kfree(rqstp->rq_enc_pages);
1356}
1357
1358static int
1359alloc_enc_pages(struct rpc_rqst *rqstp)
1360{
1361	struct xdr_buf *snd_buf = &rqstp->rq_snd_buf;
1362	int first, last, i;
1363
1364	if (snd_buf->page_len == 0) {
1365		rqstp->rq_enc_pages_num = 0;
1366		return 0;
1367	}
1368
1369	first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1370	last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_CACHE_SHIFT;
1371	rqstp->rq_enc_pages_num = last - first + 1 + 1;
1372	rqstp->rq_enc_pages
1373		= kmalloc(rqstp->rq_enc_pages_num * sizeof(struct page *),
1374				GFP_NOFS);
1375	if (!rqstp->rq_enc_pages)
1376		goto out;
1377	for (i=0; i < rqstp->rq_enc_pages_num; i++) {
1378		rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS);
1379		if (rqstp->rq_enc_pages[i] == NULL)
1380			goto out_free;
1381	}
1382	rqstp->rq_release_snd_buf = priv_release_snd_buf;
1383	return 0;
1384out_free:
1385	rqstp->rq_enc_pages_num = i;
1386	priv_release_snd_buf(rqstp);
1387out:
1388	return -EAGAIN;
1389}
1390
/*
 * Wrap a request for RPC_GSS_SVC_PRIVACY.
 *
 * Writes, from @p: a length word (back-filled below), the sequence number,
 * and the encoded arguments. The send buffer's page array is then
 * redirected to freshly allocated rq_enc_pages, and the original pages
 * (inpages) are handed to gss_wrap(). The tail is copied into the last
 * spare encryption page. gss_wrap() transforms everything from @offset
 * onward. Finally the opaque length is filled in and the buffer is
 * zero-padded to a 4-byte boundary.
 *
 * Returns 0 on success, including GSS_S_CONTEXT_EXPIRED, which only clears
 * RPCAUTH_CRED_UPTODATE. Returns the alloc_enc_pages() error, or -EIO if
 * gss_wrap() fails.
 */
static inline int
gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
		  kxdreproc_t encode, struct rpc_rqst *rqstp,
		  __be32 *p, void *obj)
{
	struct xdr_buf	*snd_buf = &rqstp->rq_snd_buf;
	u32		offset;
	u32             maj_stat;
	int		status;
	__be32		*opaque_len;
	struct page	**inpages;
	int		first;
	int		pad;
	struct kvec	*iov;
	char		*tmp;

	/* Reserve the opaque length word; filled in after wrapping. */
	opaque_len = p++;
	/* Byte offset (from the head) where the wrapped region starts. */
	offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
	*p++ = htonl(rqstp->rq_seqno);

	gss_wrap_req_encode(encode, rqstp, p, obj);

	status = alloc_enc_pages(rqstp);
	if (status)
		return status;
	/* Keep the original pages as input; output goes to rq_enc_pages. */
	first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
	inpages = snd_buf->pages + first;
	snd_buf->pages = rqstp->rq_enc_pages;
	snd_buf->page_base -= first << PAGE_CACHE_SHIFT;
	/*
	 * Give the tail its own page, in case we need extra space in the
	 * head when wrapping:
	 *
	 * call_allocate() allocates twice the slack space required
	 * by the authentication flavor to rq_callsize.
	 * For GSS, slack is GSS_CRED_SLACK.
	 *
	 * NOTE(review): if page_len is 0 but the tail is non-empty,
	 * alloc_enc_pages() sets rq_enc_pages_num to 0 and the index below
	 * becomes -1 - confirm that combination cannot reach this path.
	 */
	if (snd_buf->page_len || snd_buf->tail[0].iov_len) {
		tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]);
		memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len);
		snd_buf->tail[0].iov_base = tmp;
	}
	maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages);
	/* slack space should prevent this ever happening: */
	BUG_ON(snd_buf->len > snd_buf->buflen);
	status = -EIO;
	/* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
	 * done anyway, so it's safe to put the request on the wire: */
	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
	else if (maj_stat)
		return status;

	*opaque_len = htonl(snd_buf->len - offset);
	/* guess whether we're in the head or the tail: */
	if (snd_buf->page_len || snd_buf->tail[0].iov_len)
		iov = snd_buf->tail;
	else
		iov = snd_buf->head;
	p = iov->iov_base + iov->iov_len;
	/* Bytes (0..3) needed to bring the wrapped length to a multiple of 4. */
	pad = 3 - ((snd_buf->len - offset - 1) & 3);
	memset(p, 0, pad);
	iov->iov_len += pad;
	snd_buf->len += pad;

	return 0;
}
1458
1459static int
1460gss_wrap_req(struct rpc_task *task,
1461	     kxdreproc_t encode, void *rqstp, __be32 *p, void *obj)
1462{
1463	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1464	struct gss_cred	*gss_cred = container_of(cred, struct gss_cred,
1465			gc_base);
1466	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1467	int             status = -EIO;
1468
1469	dprintk("RPC: %5u gss_wrap_req\n", task->tk_pid);
1470	if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
1471		/* The spec seems a little ambiguous here, but I think that not
1472		 * wrapping context destruction requests makes the most sense.
1473		 */
1474		gss_wrap_req_encode(encode, rqstp, p, obj);
1475		status = 0;
1476		goto out;
1477	}
1478	switch (gss_cred->gc_service) {
1479	case RPC_GSS_SVC_NONE:
1480		gss_wrap_req_encode(encode, rqstp, p, obj);
1481		status = 0;
1482		break;
1483	case RPC_GSS_SVC_INTEGRITY:
1484		status = gss_wrap_req_integ(cred, ctx, encode, rqstp, p, obj);
1485		break;
1486	case RPC_GSS_SVC_PRIVACY:
1487		status = gss_wrap_req_priv(cred, ctx, encode, rqstp, p, obj);
1488		break;
1489	}
1490out:
1491	gss_put_ctx(ctx);
1492	dprintk("RPC: %5u gss_wrap_req returning %d\n", task->tk_pid, status);
1493	return status;
1494}
1495
1496static inline int
1497gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1498		struct rpc_rqst *rqstp, __be32 **p)
1499{
1500	struct xdr_buf	*rcv_buf = &rqstp->rq_rcv_buf;
1501	struct xdr_buf integ_buf;
1502	struct xdr_netobj mic;
1503	u32 data_offset, mic_offset;
1504	u32 integ_len;
1505	u32 maj_stat;
1506	int status = -EIO;
1507
1508	integ_len = ntohl(*(*p)++);
1509	if (integ_len & 3)
1510		return status;
1511	data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1512	mic_offset = integ_len + data_offset;
1513	if (mic_offset > rcv_buf->len)
1514		return status;
1515	if (ntohl(*(*p)++) != rqstp->rq_seqno)
1516		return status;
1517
1518	if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
1519				mic_offset - data_offset))
1520		return status;
1521
1522	if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
1523		return status;
1524
1525	maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1526	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1527		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1528	if (maj_stat != GSS_S_COMPLETE)
1529		return status;
1530	return 0;
1531}
1532
1533static inline int
1534gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1535		struct rpc_rqst *rqstp, __be32 **p)
1536{
1537	struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1538	u32 offset;
1539	u32 opaque_len;
1540	u32 maj_stat;
1541	int status = -EIO;
1542
1543	opaque_len = ntohl(*(*p)++);
1544	offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1545	if (offset + opaque_len > rcv_buf->len)
1546		return status;
1547	/* remove padding: */
1548	rcv_buf->len = offset + opaque_len;
1549
1550	maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf);
1551	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1552		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1553	if (maj_stat != GSS_S_COMPLETE)
1554		return status;
1555	if (ntohl(*(*p)++) != rqstp->rq_seqno)
1556		return status;
1557
1558	return 0;
1559}
1560
1561static int
1562gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
1563		      __be32 *p, void *obj)
1564{
1565	struct xdr_stream xdr;
1566
1567	xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p);
1568	return decode(rqstp, &xdr, obj);
1569}
1570
1571static int
1572gss_unwrap_resp(struct rpc_task *task,
1573		kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj)
1574{
1575	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1576	struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
1577			gc_base);
1578	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1579	__be32		*savedp = p;
1580	struct kvec	*head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head;
1581	int		savedlen = head->iov_len;
1582	int             status = -EIO;
1583
1584	if (ctx->gc_proc != RPC_GSS_PROC_DATA)
1585		goto out_decode;
1586	switch (gss_cred->gc_service) {
1587	case RPC_GSS_SVC_NONE:
1588		break;
1589	case RPC_GSS_SVC_INTEGRITY:
1590		status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p);
1591		if (status)
1592			goto out;
1593		break;
1594	case RPC_GSS_SVC_PRIVACY:
1595		status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p);
1596		if (status)
1597			goto out;
1598		break;
1599	}
1600	/* take into account extra slack for integrity and privacy cases: */
1601	cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
1602						+ (savedlen - head->iov_len);
1603out_decode:
1604	status = gss_unwrap_req_decode(decode, rqstp, p, obj);
1605out:
1606	gss_put_ctx(ctx);
1607	dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid,
1608			status);
1609	return status;
1610}
1611
1612static const struct rpc_authops authgss_ops = {
1613	.owner		= THIS_MODULE,
1614	.au_flavor	= RPC_AUTH_GSS,
1615	.au_name	= "RPCSEC_GSS",
1616	.create		= gss_create,
1617	.destroy	= gss_destroy,
1618	.lookup_cred	= gss_lookup_cred,
1619	.crcreate	= gss_create_cred,
1620	.pipes_create	= gss_pipes_dentries_create,
1621	.pipes_destroy	= gss_pipes_dentries_destroy,
1622};
1623
1624static const struct rpc_credops gss_credops = {
1625	.cr_name	= "AUTH_GSS",
1626	.crdestroy	= gss_destroy_cred,
1627	.cr_init	= gss_cred_init,
1628	.crbind		= rpcauth_generic_bind_cred,
1629	.crmatch	= gss_match,
1630	.crmarshal	= gss_marshal,
1631	.crrefresh	= gss_refresh,
1632	.crvalidate	= gss_validate,
1633	.crwrap_req	= gss_wrap_req,
1634	.crunwrap_resp	= gss_unwrap_resp,
1635};
1636
1637static const struct rpc_credops gss_nullops = {
1638	.cr_name	= "AUTH_GSS",
1639	.crdestroy	= gss_destroy_nullcred,
1640	.crbind		= rpcauth_generic_bind_cred,
1641	.crmatch	= gss_match,
1642	.crmarshal	= gss_marshal,
1643	.crrefresh	= gss_refresh_null,
1644	.crvalidate	= gss_validate,
1645	.crwrap_req	= gss_wrap_req,
1646	.crunwrap_resp	= gss_unwrap_resp,
1647};
1648
1649static const struct rpc_pipe_ops gss_upcall_ops_v0 = {
1650	.upcall		= rpc_pipe_generic_upcall,
1651	.downcall	= gss_pipe_downcall,
1652	.destroy_msg	= gss_pipe_destroy_msg,
1653	.open_pipe	= gss_pipe_open_v0,
1654	.release_pipe	= gss_pipe_release,
1655};
1656
1657static const struct rpc_pipe_ops gss_upcall_ops_v1 = {
1658	.upcall		= rpc_pipe_generic_upcall,
1659	.downcall	= gss_pipe_downcall,
1660	.destroy_msg	= gss_pipe_destroy_msg,
1661	.open_pipe	= gss_pipe_open_v1,
1662	.release_pipe	= gss_pipe_release,
1663};
1664
1665static __net_init int rpcsec_gss_init_net(struct net *net)
1666{
1667	return gss_svc_init_net(net);
1668}
1669
/* Per-network-namespace exit hook: delegates to gss_svc_shutdown_net(). */
static __net_exit void rpcsec_gss_exit_net(struct net *net)
{
	gss_svc_shutdown_net(net);
}
1674
1675static struct pernet_operations rpcsec_gss_net_ops = {
1676	.init = rpcsec_gss_init_net,
1677	.exit = rpcsec_gss_exit_net,
1678};
1679
1680/*
1681 * Initialize RPCSEC_GSS module
1682 */
1683static int __init init_rpcsec_gss(void)
1684{
1685	int err = 0;
1686
1687	err = rpcauth_register(&authgss_ops);
1688	if (err)
1689		goto out;
1690	err = gss_svc_init();
1691	if (err)
1692		goto out_unregister;
1693	err = register_pernet_subsys(&rpcsec_gss_net_ops);
1694	if (err)
1695		goto out_svc_exit;
1696	rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version");
1697	return 0;
1698out_svc_exit:
1699	gss_svc_shutdown();
1700out_unregister:
1701	rpcauth_unregister(&authgss_ops);
1702out:
1703	return err;
1704}
1705
/*
 * Module exit: undo init_rpcsec_gss() registrations in reverse order,
 * then wait for outstanding call_rcu() callbacks. Presumably this ensures
 * none of them runs after the module is unloaded.
 */
static void __exit exit_rpcsec_gss(void)
{
	unregister_pernet_subsys(&rpcsec_gss_net_ops);
	gss_svc_shutdown();
	rpcauth_unregister(&authgss_ops);
	rcu_barrier(); /* Wait for completion of call_rcu()'s */
}
1713
MODULE_LICENSE("GPL");
/*
 * expired_cred_retry_delay: number of seconds after gc_upcall_timestamp
 * during which a credential flagged RPCAUTH_CRED_NEGATIVE makes
 * gss_refresh() fail with -EKEYEXPIRED instead of attempting a new
 * upcall (see gss_cred_is_negative_entry()). Exposed with permissions
 * 0644.
 */
module_param_named(expired_cred_retry_delay,
		   gss_expired_cred_retry_delay,
		   uint, 0644);
MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until "
		"the RPC engine retries an expired credential");

/* Entry and exit points of the RPCSEC_GSS module. */
module_init(init_rpcsec_gss)
module_exit(exit_rpcsec_gss)
1723