auth_gss.c revision a4f0835c604f80f945ab3e72ffd00547145c4b2b
1/*
2 * linux/net/sunrpc/auth_gss/auth_gss.c
3 *
4 * RPCSEC_GSS client authentication.
5 *
6 *  Copyright (c) 2000 The Regents of the University of Michigan.
7 *  All rights reserved.
8 *
9 *  Dug Song       <dugsong@monkey.org>
10 *  Andy Adamson   <andros@umich.edu>
11 *
12 *  Redistribution and use in source and binary forms, with or without
13 *  modification, are permitted provided that the following conditions
14 *  are met:
15 *
16 *  1. Redistributions of source code must retain the above copyright
17 *     notice, this list of conditions and the following disclaimer.
18 *  2. Redistributions in binary form must reproduce the above copyright
19 *     notice, this list of conditions and the following disclaimer in the
20 *     documentation and/or other materials provided with the distribution.
21 *  3. Neither the name of the University nor the names of its
22 *     contributors may be used to endorse or promote products derived
23 *     from this software without specific prior written permission.
24 *
25 *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
26 *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
27 *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
29 *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30 *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31 *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
32 *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 */
37
38
39#include <linux/module.h>
40#include <linux/init.h>
41#include <linux/types.h>
42#include <linux/slab.h>
43#include <linux/sched.h>
44#include <linux/pagemap.h>
45#include <linux/sunrpc/clnt.h>
46#include <linux/sunrpc/auth.h>
47#include <linux/sunrpc/auth_gss.h>
48#include <linux/sunrpc/svcauth_gss.h>
49#include <linux/sunrpc/gss_err.h>
50#include <linux/workqueue.h>
51#include <linux/sunrpc/rpc_pipe_fs.h>
52#include <linux/sunrpc/gss_api.h>
53#include <asm/uaccess.h>
54
/* Operation tables; defined later in this file. */
static const struct rpc_authops authgss_ops;

static const struct rpc_credops gss_credops;
static const struct rpc_credops gss_nullops;

/*
 * Delay applied before retrying with an expired credential.
 * NOTE(review): presumably in seconds; the variable is not used in the
 * visible part of this file — confirm against its consumer.
 */
#define GSS_RETRY_EXPIRED 5
static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED;

#ifdef RPC_DEBUG
# define RPCDBG_FACILITY	RPCDBG_AUTH
#endif

/* Credential slack in bytes; gss_create() stores it as GSS_CRED_SLACK >> 2. */
#define GSS_CRED_SLACK		(RPC_MAX_AUTH_SIZE * 2)
/* length of a krb5 verifier (48), plus data added before arguments when
 * using integrity (two 4-byte integers): */
#define GSS_VERF_SLACK		100
71
/*
 * Per-client RPCSEC_GSS authenticator.  Generic RPC code only sees the
 * embedded rpc_auth; container_of() recovers the gss_auth from it.
 */
struct gss_auth {
	/* Lifetime count: set to 1 in gss_create(), +1 per cred in
	 * gss_create_cred(); gss_free() runs when the last one is put. */
	struct kref kref;
	struct rpc_auth rpc_auth;
	/* Mechanism resolved from the pseudoflavor; released in gss_free(). */
	struct gss_api_mech *mech;
	/* Protection service derived from the pseudoflavor. */
	enum rpc_gss_svc service;
	/* Client this authenticator was created for. */
	struct rpc_clnt *client;
	/*
	 * There are two upcall pipes; pipe[1], named "gssd", is used
	 * for the new text-based upcall; pipe[0] is named after the
	 * mechanism (for example, "krb5") and exists for
	 * backwards-compatibility with older gssd's.
	 */
	struct rpc_pipe *pipe[2];
};
86
/*
 * pipe_version >= 0 while the upcall version is pinned: from the first
 * open of a gss pipe until pipe_users drops back to zero.
 */
static int pipe_version = -1;
/*
 * Holders of the current pipe version: each open pipe (gss_pipe_open())
 * and each live upcall message (get_pipe_version()).  put_pipe_version()
 * resets pipe_version to -1 when the count reaches zero.
 */
static atomic_t pipe_users = ATOMIC_INIT(0);
/* Serialises changes to pipe_version. */
static DEFINE_SPINLOCK(pipe_version_lock);
/* rpc_tasks waiting for gssd to open a pipe (see gss_refresh_upcall()).
 * NOTE(review): its initialisation is not in the visible part of the file. */
static struct rpc_wait_queue pipe_version_rpc_waitqueue;
/* Process-context callers waiting for a pipe (see gss_create_upcall()). */
static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);

static void gss_free_ctx(struct gss_cl_ctx *);
static const struct rpc_pipe_ops gss_upcall_ops_v0;
static const struct rpc_pipe_ops gss_upcall_ops_v1;
97
98static inline struct gss_cl_ctx *
99gss_get_ctx(struct gss_cl_ctx *ctx)
100{
101	atomic_inc(&ctx->count);
102	return ctx;
103}
104
105static inline void
106gss_put_ctx(struct gss_cl_ctx *ctx)
107{
108	if (atomic_dec_and_test(&ctx->count))
109		gss_free_ctx(ctx);
110}
111
/* gss_cred_set_ctx:
 * called by gss_handle_downcall_result (on behalf of gss_upcall_callback
 * and gss_refresh_upcall) and by gss_create_upcall in order to set the gss
 * context. All of these callers hold pipe->lock, which protects the
 * installation of the new context.
 *
 * Does nothing unless the cred is still RPCAUTH_CRED_NEW.  Otherwise the
 * cred takes its own reference on @ctx, the pointer is published for RCU
 * readers, the cred is marked UPTODATE and, last, NEW is cleared.
 */
static void
gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
{
	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);

	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
		return;
	gss_get_ctx(ctx);
	rcu_assign_pointer(gss_cred->gc_ctx, ctx);
	set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
	/* Order the stores above before NEW is cleared: gss_match() treats a
	 * cred without NEW as one that has a context attached. */
	smp_mb__before_clear_bit();
	clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
}
130
131static const void *
132simple_get_bytes(const void *p, const void *end, void *res, size_t len)
133{
134	const void *q = (const void *)((const char *)p + len);
135	if (unlikely(q > end || q < p))
136		return ERR_PTR(-EFAULT);
137	memcpy(res, p, len);
138	return q;
139}
140
141static inline const void *
142simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest)
143{
144	const void *q;
145	unsigned int len;
146
147	p = simple_get_bytes(p, end, &len, sizeof(len));
148	if (IS_ERR(p))
149		return p;
150	q = (const void *)((const char *)p + len);
151	if (unlikely(q > end || q < p))
152		return ERR_PTR(-EFAULT);
153	dest->data = kmemdup(p, len, GFP_NOFS);
154	if (unlikely(dest->data == NULL))
155		return ERR_PTR(-ENOMEM);
156	dest->len = len;
157	return q;
158}
159
160static struct gss_cl_ctx *
161gss_cred_get_ctx(struct rpc_cred *cred)
162{
163	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
164	struct gss_cl_ctx *ctx = NULL;
165
166	rcu_read_lock();
167	if (gss_cred->gc_ctx)
168		ctx = gss_get_ctx(gss_cred->gc_ctx);
169	rcu_read_unlock();
170	return ctx;
171}
172
173static struct gss_cl_ctx *
174gss_alloc_context(void)
175{
176	struct gss_cl_ctx *ctx;
177
178	ctx = kzalloc(sizeof(*ctx), GFP_NOFS);
179	if (ctx != NULL) {
180		ctx->gc_proc = RPC_GSS_PROC_DATA;
181		ctx->gc_seq = 1;	/* NetApp 6.4R1 doesn't accept seq. no. 0 */
182		spin_lock_init(&ctx->gc_seq_lock);
183		atomic_set(&ctx->count,1);
184	}
185	return ctx;
186}
187
188#define GSSD_MIN_TIMEOUT (60 * 60)
189static const void *
190gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm)
191{
192	const void *q;
193	unsigned int seclen;
194	unsigned int timeout;
195	unsigned long now = jiffies;
196	u32 window_size;
197	int ret;
198
199	/* First unsigned int gives the remaining lifetime in seconds of the
200	 * credential - e.g. the remaining TGT lifetime for Kerberos or
201	 * the -t value passed to GSSD.
202	 */
203	p = simple_get_bytes(p, end, &timeout, sizeof(timeout));
204	if (IS_ERR(p))
205		goto err;
206	if (timeout == 0)
207		timeout = GSSD_MIN_TIMEOUT;
208	ctx->gc_expiry = now + ((unsigned long)timeout * HZ);
209	/* Sequence number window. Determines the maximum number of
210	 * simultaneous requests
211	 */
212	p = simple_get_bytes(p, end, &window_size, sizeof(window_size));
213	if (IS_ERR(p))
214		goto err;
215	ctx->gc_win = window_size;
216	/* gssd signals an error by passing ctx->gc_win = 0: */
217	if (ctx->gc_win == 0) {
218		/*
219		 * in which case, p points to an error code. Anything other
220		 * than -EKEYEXPIRED gets converted to -EACCES.
221		 */
222		p = simple_get_bytes(p, end, &ret, sizeof(ret));
223		if (!IS_ERR(p))
224			p = (ret == -EKEYEXPIRED) ? ERR_PTR(-EKEYEXPIRED) :
225						    ERR_PTR(-EACCES);
226		goto err;
227	}
228	/* copy the opaque wire context */
229	p = simple_get_netobj(p, end, &ctx->gc_wire_ctx);
230	if (IS_ERR(p))
231		goto err;
232	/* import the opaque security context */
233	p  = simple_get_bytes(p, end, &seclen, sizeof(seclen));
234	if (IS_ERR(p))
235		goto err;
236	q = (const void *)((const char *)p + seclen);
237	if (unlikely(q > end || q < p)) {
238		p = ERR_PTR(-EFAULT);
239		goto err;
240	}
241	ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, GFP_NOFS);
242	if (ret < 0) {
243		p = ERR_PTR(ret);
244		goto err;
245	}
246	dprintk("RPC:       %s Success. gc_expiry %lu now %lu timeout %u\n",
247		__func__, ctx->gc_expiry, now, timeout);
248	return q;
249err:
250	dprintk("RPC:       %s returns %ld gc_expiry %lu now %lu timeout %u\n",
251		__func__, -PTR_ERR(p), ctx->gc_expiry, now, timeout);
252	return p;
253}
254
/* Size of the text buffer used to build a version 1 upcall. */
#define UPCALL_BUF_LEN 128

/*
 * One outstanding upcall to gssd for a uid.  Everybody who needs a context
 * for that uid shares the same message (see gss_add_msg()).
 */
struct gss_upcall_msg {
	/* Reference count; gss_release_msg() frees the message at zero. */
	atomic_t count;
	/* Requesting user; the key used by __gss_find_upcall(). */
	uid_t	uid;
	/* Message queued on the pipe; msg.errno carries the downcall result. */
	struct rpc_pipe_msg msg;
	/* Link on pipe->in_downcall while gssd's reply is awaited. */
	struct list_head list;
	struct gss_auth *auth;
	/* Pipe (version 0 or 1) the upcall is sent on. */
	struct rpc_pipe *pipe;
	/* rpc_tasks waiting for the reply (gss_refresh_upcall()). */
	struct rpc_wait_queue rpc_waitqueue;
	/* Process-context waiters (gss_create_upcall()). */
	wait_queue_head_t waitqueue;
	/* Context delivered by the downcall, or NULL. */
	struct gss_cl_ctx *ctx;
	/* Backing store for msg.data in version 1 upcalls. */
	char databuf[UPCALL_BUF_LEN];
};
269
270static int get_pipe_version(void)
271{
272	int ret;
273
274	spin_lock(&pipe_version_lock);
275	if (pipe_version >= 0) {
276		atomic_inc(&pipe_users);
277		ret = pipe_version;
278	} else
279		ret = -EAGAIN;
280	spin_unlock(&pipe_version_lock);
281	return ret;
282}
283
284static void put_pipe_version(void)
285{
286	if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) {
287		pipe_version = -1;
288		spin_unlock(&pipe_version_lock);
289	}
290}
291
292static void
293gss_release_msg(struct gss_upcall_msg *gss_msg)
294{
295	if (!atomic_dec_and_test(&gss_msg->count))
296		return;
297	put_pipe_version();
298	BUG_ON(!list_empty(&gss_msg->list));
299	if (gss_msg->ctx != NULL)
300		gss_put_ctx(gss_msg->ctx);
301	rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue);
302	kfree(gss_msg);
303}
304
305static struct gss_upcall_msg *
306__gss_find_upcall(struct rpc_pipe *pipe, uid_t uid)
307{
308	struct gss_upcall_msg *pos;
309	list_for_each_entry(pos, &pipe->in_downcall, list) {
310		if (pos->uid != uid)
311			continue;
312		atomic_inc(&pos->count);
313		dprintk("RPC:       %s found msg %p\n", __func__, pos);
314		return pos;
315	}
316	dprintk("RPC:       %s found nothing\n", __func__);
317	return NULL;
318}
319
320/* Try to add an upcall to the pipefs queue.
321 * If an upcall owned by our uid already exists, then we return a reference
322 * to that upcall instead of adding the new upcall.
323 */
324static inline struct gss_upcall_msg *
325gss_add_msg(struct gss_upcall_msg *gss_msg)
326{
327	struct rpc_pipe *pipe = gss_msg->pipe;
328	struct gss_upcall_msg *old;
329
330	spin_lock(&pipe->lock);
331	old = __gss_find_upcall(pipe, gss_msg->uid);
332	if (old == NULL) {
333		atomic_inc(&gss_msg->count);
334		list_add(&gss_msg->list, &pipe->in_downcall);
335	} else
336		gss_msg = old;
337	spin_unlock(&pipe->lock);
338	return gss_msg;
339}
340
/*
 * __gss_unhash_msg - take @gss_msg off pipe->in_downcall and wake waiters.
 *
 * Called with pipe->lock held (every caller in this file does).  Both the
 * rpc_task waiters and the process-context waiters are woken.  Then the
 * reference owned by the in_downcall list (taken in gss_add_msg()) is
 * dropped.
 *
 * NOTE(review): every visible caller holds its own reference across this
 * call, which is presumably why a plain atomic_dec() (rather than
 * gss_release_msg()) is safe here.
 */
static void
__gss_unhash_msg(struct gss_upcall_msg *gss_msg)
{
	list_del_init(&gss_msg->list);
	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
	wake_up_all(&gss_msg->waitqueue);
	atomic_dec(&gss_msg->count);
}
349
/*
 * gss_unhash_msg - locked wrapper around __gss_unhash_msg().
 *
 * The unlocked list_empty() test only short-cuts the already-unhashed
 * case.  It is repeated under pipe->lock so that two racing callers cannot
 * both unhash the message and drop the list's reference twice.
 */
static void
gss_unhash_msg(struct gss_upcall_msg *gss_msg)
{
	struct rpc_pipe *pipe = gss_msg->pipe;

	if (list_empty(&gss_msg->list))
		return;
	spin_lock(&pipe->lock);
	if (!list_empty(&gss_msg->list))
		__gss_unhash_msg(gss_msg);
	spin_unlock(&pipe->lock);
}
362
363static void
364gss_handle_downcall_result(struct gss_cred *gss_cred, struct gss_upcall_msg *gss_msg)
365{
366	switch (gss_msg->msg.errno) {
367	case 0:
368		if (gss_msg->ctx == NULL)
369			break;
370		clear_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
371		gss_cred_set_ctx(&gss_cred->gc_base, gss_msg->ctx);
372		break;
373	case -EKEYEXPIRED:
374		set_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
375	}
376	gss_cred->gc_upcall_timestamp = jiffies;
377	gss_cred->gc_upcall = NULL;
378	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
379}
380
381static void
382gss_upcall_callback(struct rpc_task *task)
383{
384	struct gss_cred *gss_cred = container_of(task->tk_rqstp->rq_cred,
385			struct gss_cred, gc_base);
386	struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall;
387	struct rpc_pipe *pipe = gss_msg->pipe;
388
389	spin_lock(&pipe->lock);
390	gss_handle_downcall_result(gss_cred, gss_msg);
391	spin_unlock(&pipe->lock);
392	task->tk_status = gss_msg->msg.errno;
393	gss_release_msg(gss_msg);
394}
395
396static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
397{
398	gss_msg->msg.data = &gss_msg->uid;
399	gss_msg->msg.len = sizeof(gss_msg->uid);
400}
401
402static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
403				struct rpc_clnt *clnt,
404				const char *service_name)
405{
406	struct gss_api_mech *mech = gss_msg->auth->mech;
407	char *p = gss_msg->databuf;
408	int len = 0;
409
410	gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ",
411				   mech->gm_name,
412				   gss_msg->uid);
413	p += gss_msg->msg.len;
414	if (clnt->cl_principal) {
415		len = sprintf(p, "target=%s ", clnt->cl_principal);
416		p += len;
417		gss_msg->msg.len += len;
418	}
419	if (service_name != NULL) {
420		len = sprintf(p, "service=%s ", service_name);
421		p += len;
422		gss_msg->msg.len += len;
423	}
424	if (mech->gm_upcall_enctypes) {
425		len = sprintf(p, "enctypes=%s ", mech->gm_upcall_enctypes);
426		p += len;
427		gss_msg->msg.len += len;
428	}
429	len = sprintf(p, "\n");
430	gss_msg->msg.len += len;
431
432	gss_msg->msg.data = gss_msg->databuf;
433	BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN);
434}
435
436static void gss_encode_msg(struct gss_upcall_msg *gss_msg,
437				struct rpc_clnt *clnt,
438				const char *service_name)
439{
440	if (pipe_version == 0)
441		gss_encode_v0_msg(gss_msg);
442	else /* pipe_version == 1 */
443		gss_encode_v1_msg(gss_msg, clnt, service_name);
444}
445
446static struct gss_upcall_msg *
447gss_alloc_msg(struct gss_auth *gss_auth, struct rpc_clnt *clnt,
448		uid_t uid, const char *service_name)
449{
450	struct gss_upcall_msg *gss_msg;
451	int vers;
452
453	gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
454	if (gss_msg == NULL)
455		return ERR_PTR(-ENOMEM);
456	vers = get_pipe_version();
457	if (vers < 0) {
458		kfree(gss_msg);
459		return ERR_PTR(vers);
460	}
461	gss_msg->pipe = gss_auth->pipe[vers];
462	INIT_LIST_HEAD(&gss_msg->list);
463	rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
464	init_waitqueue_head(&gss_msg->waitqueue);
465	atomic_set(&gss_msg->count, 1);
466	gss_msg->uid = uid;
467	gss_msg->auth = gss_auth;
468	gss_encode_msg(gss_msg, clnt, service_name);
469	return gss_msg;
470}
471
472static struct gss_upcall_msg *
473gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
474{
475	struct gss_cred *gss_cred = container_of(cred,
476			struct gss_cred, gc_base);
477	struct gss_upcall_msg *gss_new, *gss_msg;
478	uid_t uid = cred->cr_uid;
479
480	gss_new = gss_alloc_msg(gss_auth, clnt, uid, gss_cred->gc_principal);
481	if (IS_ERR(gss_new))
482		return gss_new;
483	gss_msg = gss_add_msg(gss_new);
484	if (gss_msg == gss_new) {
485		int res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg);
486		if (res) {
487			gss_unhash_msg(gss_new);
488			gss_msg = ERR_PTR(res);
489		}
490	} else
491		gss_release_msg(gss_new);
492	return gss_msg;
493}
494
495static void warn_gssd(void)
496{
497	static unsigned long ratelimit;
498	unsigned long now = jiffies;
499
500	if (time_after(now, ratelimit)) {
501		printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
502				"Please check user daemon is running.\n");
503		ratelimit = now + 15*HZ;
504	}
505}
506
/*
 * gss_refresh_upcall - asynchronous (rpc_task) credential refresh.
 *
 * Starts, or joins, an upcall to gssd for the credential of @task.
 *
 * Returns:
 *   -EAGAIN  no gss pipe is open; @task now sleeps on
 *            pipe_version_rpc_waitqueue with a 15*HZ timeout.
 *   0        @task sleeps behind a pending upcall, or a context was already
 *            delivered and has been applied.
 *   <0       setting up the upcall failed, or the downcall reported an
 *            error (msg.errno).
 */
static inline int
gss_refresh_upcall(struct rpc_task *task)
{
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
	struct gss_auth *gss_auth = container_of(cred->cr_auth,
			struct gss_auth, rpc_auth);
	struct gss_cred *gss_cred = container_of(cred,
			struct gss_cred, gc_base);
	struct gss_upcall_msg *gss_msg;
	struct rpc_pipe *pipe;
	int err = 0;

	dprintk("RPC: %5u %s for uid %u\n",
		task->tk_pid, __func__, cred->cr_uid);
	gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
	if (PTR_ERR(gss_msg) == -EAGAIN) {
		/* XXX: warning on the first, under the assumption we
		 * shouldn't normally hit this case on a refresh. */
		warn_gssd();
		task->tk_timeout = 15*HZ;
		rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL);
		return -EAGAIN;
	}
	if (IS_ERR(gss_msg)) {
		err = PTR_ERR(gss_msg);
		goto out;
	}
	pipe = gss_msg->pipe;
	/*
	 * Under pipe->lock, one of three cases:
	 *  - this cred already has an owning upcall: just wait behind it;
	 *  - the upcall is still pending: become its owner (with an extra
	 *    reference) and sleep until gss_upcall_callback() runs;
	 *  - the downcall has already arrived: apply its result now.
	 */
	spin_lock(&pipe->lock);
	if (gss_cred->gc_upcall != NULL)
		rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
	else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
		task->tk_timeout = 0;
		gss_cred->gc_upcall = gss_msg;
		/* gss_upcall_callback will release the reference to gss_upcall_msg */
		atomic_inc(&gss_msg->count);
		rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback);
	} else {
		gss_handle_downcall_result(gss_cred, gss_msg);
		err = gss_msg->msg.errno;
	}
	spin_unlock(&pipe->lock);
	/* Drop the reference returned by gss_setup_upcall(). */
	gss_release_msg(gss_msg);
out:
	dprintk("RPC: %5u %s for uid %u result %d\n",
		task->tk_pid, __func__, cred->cr_uid, err);
	return err;
}
555
556static inline int
557gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
558{
559	struct rpc_pipe *pipe;
560	struct rpc_cred *cred = &gss_cred->gc_base;
561	struct gss_upcall_msg *gss_msg;
562	DEFINE_WAIT(wait);
563	int err = 0;
564
565	dprintk("RPC:       %s for uid %u\n", __func__, cred->cr_uid);
566retry:
567	gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
568	if (PTR_ERR(gss_msg) == -EAGAIN) {
569		err = wait_event_interruptible_timeout(pipe_version_waitqueue,
570				pipe_version >= 0, 15*HZ);
571		if (pipe_version < 0) {
572			warn_gssd();
573			err = -EACCES;
574		}
575		if (err)
576			goto out;
577		goto retry;
578	}
579	if (IS_ERR(gss_msg)) {
580		err = PTR_ERR(gss_msg);
581		goto out;
582	}
583	pipe = gss_msg->pipe;
584	for (;;) {
585		prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_KILLABLE);
586		spin_lock(&pipe->lock);
587		if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
588			break;
589		}
590		spin_unlock(&pipe->lock);
591		if (fatal_signal_pending(current)) {
592			err = -ERESTARTSYS;
593			goto out_intr;
594		}
595		schedule();
596	}
597	if (gss_msg->ctx)
598		gss_cred_set_ctx(cred, gss_msg->ctx);
599	else
600		err = gss_msg->msg.errno;
601	spin_unlock(&pipe->lock);
602out_intr:
603	finish_wait(&gss_msg->waitqueue, &wait);
604	gss_release_msg(gss_msg);
605out:
606	dprintk("RPC:       %s for uid %u result %d\n",
607		__func__, cred->cr_uid, err);
608	return err;
609}
610
611#define MSG_BUF_MAXSIZE 1024
612
613static ssize_t
614gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
615{
616	const void *p, *end;
617	void *buf;
618	struct gss_upcall_msg *gss_msg;
619	struct rpc_pipe *pipe = RPC_I(filp->f_dentry->d_inode)->pipe;
620	struct gss_cl_ctx *ctx;
621	uid_t uid;
622	ssize_t err = -EFBIG;
623
624	if (mlen > MSG_BUF_MAXSIZE)
625		goto out;
626	err = -ENOMEM;
627	buf = kmalloc(mlen, GFP_NOFS);
628	if (!buf)
629		goto out;
630
631	err = -EFAULT;
632	if (copy_from_user(buf, src, mlen))
633		goto err;
634
635	end = (const void *)((char *)buf + mlen);
636	p = simple_get_bytes(buf, end, &uid, sizeof(uid));
637	if (IS_ERR(p)) {
638		err = PTR_ERR(p);
639		goto err;
640	}
641
642	err = -ENOMEM;
643	ctx = gss_alloc_context();
644	if (ctx == NULL)
645		goto err;
646
647	err = -ENOENT;
648	/* Find a matching upcall */
649	spin_lock(&pipe->lock);
650	gss_msg = __gss_find_upcall(pipe, uid);
651	if (gss_msg == NULL) {
652		spin_unlock(&pipe->lock);
653		goto err_put_ctx;
654	}
655	list_del_init(&gss_msg->list);
656	spin_unlock(&pipe->lock);
657
658	p = gss_fill_context(p, end, ctx, gss_msg->auth->mech);
659	if (IS_ERR(p)) {
660		err = PTR_ERR(p);
661		switch (err) {
662		case -EACCES:
663		case -EKEYEXPIRED:
664			gss_msg->msg.errno = err;
665			err = mlen;
666			break;
667		case -EFAULT:
668		case -ENOMEM:
669		case -EINVAL:
670		case -ENOSYS:
671			gss_msg->msg.errno = -EAGAIN;
672			break;
673		default:
674			printk(KERN_CRIT "%s: bad return from "
675				"gss_fill_context: %zd\n", __func__, err);
676			BUG();
677		}
678		goto err_release_msg;
679	}
680	gss_msg->ctx = gss_get_ctx(ctx);
681	err = mlen;
682
683err_release_msg:
684	spin_lock(&pipe->lock);
685	__gss_unhash_msg(gss_msg);
686	spin_unlock(&pipe->lock);
687	gss_release_msg(gss_msg);
688err_put_ctx:
689	gss_put_ctx(ctx);
690err:
691	kfree(buf);
692out:
693	dprintk("RPC:       %s returning %Zd\n", __func__, err);
694	return err;
695}
696
697static int gss_pipe_open(struct inode *inode, int new_version)
698{
699	int ret = 0;
700
701	spin_lock(&pipe_version_lock);
702	if (pipe_version < 0) {
703		/* First open of any gss pipe determines the version: */
704		pipe_version = new_version;
705		rpc_wake_up(&pipe_version_rpc_waitqueue);
706		wake_up(&pipe_version_waitqueue);
707	} else if (pipe_version != new_version) {
708		/* Trying to open a pipe of a different version */
709		ret = -EBUSY;
710		goto out;
711	}
712	atomic_inc(&pipe_users);
713out:
714	spin_unlock(&pipe_version_lock);
715	return ret;
716
717}
718
/* Open of the legacy pipe (version 0).  NOTE(review): presumably wired up
 * through gss_upcall_ops_v0, which is defined later in the file. */
static int gss_pipe_open_v0(struct inode *inode)
{
	const int version = 0;

	return gss_pipe_open(inode, version);
}

/* Open of the text-based "gssd" pipe (version 1).  NOTE(review): presumably
 * wired up through gss_upcall_ops_v1, which is defined later in the file. */
static int gss_pipe_open_v1(struct inode *inode)
{
	const int version = 1;

	return gss_pipe_open(inode, version);
}
728
/*
 * gss_pipe_release - called when an upcall pipe is closed.
 *
 * Fails, with -EPIPE, every upcall on this pipe that is waiting for a
 * downcall and whose pipe message is no longer linked, waking its waiters.
 * Then drops the pipe user taken in gss_pipe_open().
 */
static void
gss_pipe_release(struct inode *inode)
{
	struct rpc_pipe *pipe = RPC_I(inode)->pipe;
	struct gss_upcall_msg *gss_msg;

restart:
	spin_lock(&pipe->lock);
	list_for_each_entry(gss_msg, &pipe->in_downcall, list) {

		/* NOTE(review): a still-linked msg.list presumably means the
		 * message is still queued in the pipe and will be cleaned up
		 * through gss_pipe_destroy_msg() instead — confirm. */
		if (!list_empty(&gss_msg->msg.list))
			continue;
		gss_msg->msg.errno = -EPIPE;
		/* Keep the message alive across the unlock below. */
		atomic_inc(&gss_msg->count);
		__gss_unhash_msg(gss_msg);
		spin_unlock(&pipe->lock);
		gss_release_msg(gss_msg);
		/* The list may have changed while unlocked: rescan it. */
		goto restart;
	}
	spin_unlock(&pipe->lock);

	put_pipe_version();
}
752
753static void
754gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
755{
756	struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
757
758	if (msg->errno < 0) {
759		dprintk("RPC:       %s releasing msg %p\n",
760			__func__, gss_msg);
761		atomic_inc(&gss_msg->count);
762		gss_unhash_msg(gss_msg);
763		if (msg->errno == -ETIMEDOUT)
764			warn_gssd();
765		gss_release_msg(gss_msg);
766	}
767}
768
769static void gss_pipes_dentries_destroy(struct rpc_auth *auth)
770{
771	struct gss_auth *gss_auth;
772
773	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
774	if (gss_auth->pipe[0]->dentry)
775		rpc_unlink(gss_auth->pipe[0]->dentry);
776	if (gss_auth->pipe[1]->dentry)
777		rpc_unlink(gss_auth->pipe[1]->dentry);
778}
779
780static int gss_pipes_dentries_create(struct rpc_auth *auth)
781{
782	int err;
783	struct gss_auth *gss_auth;
784	struct rpc_clnt *clnt;
785
786	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
787	clnt = gss_auth->client;
788
789	gss_auth->pipe[1]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry,
790						      "gssd",
791						      clnt, gss_auth->pipe[1]);
792	if (IS_ERR(gss_auth->pipe[1]->dentry))
793		return PTR_ERR(gss_auth->pipe[1]->dentry);
794	gss_auth->pipe[0]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry,
795						      gss_auth->mech->gm_name,
796						      clnt, gss_auth->pipe[0]);
797	if (IS_ERR(gss_auth->pipe[0]->dentry)) {
798		err = PTR_ERR(gss_auth->pipe[0]->dentry);
799		goto err_unlink_pipe_1;
800	}
801	return 0;
802
803err_unlink_pipe_1:
804	rpc_unlink(gss_auth->pipe[1]->dentry);
805	return err;
806}
807
808static void gss_pipes_dentries_destroy_net(struct rpc_clnt *clnt,
809					   struct rpc_auth *auth)
810{
811	struct net *net = rpc_net_ns(clnt);
812	struct super_block *sb;
813
814	sb = rpc_get_sb_net(net);
815	if (sb) {
816		if (clnt->cl_dentry)
817			gss_pipes_dentries_destroy(auth);
818		rpc_put_sb_net(net);
819	}
820}
821
822static int gss_pipes_dentries_create_net(struct rpc_clnt *clnt,
823					 struct rpc_auth *auth)
824{
825	struct net *net = rpc_net_ns(clnt);
826	struct super_block *sb;
827	int err = 0;
828
829	sb = rpc_get_sb_net(net);
830	if (sb) {
831		if (clnt->cl_dentry)
832			err = gss_pipes_dentries_create(auth);
833		rpc_put_sb_net(net);
834	}
835	return err;
836}
837
/*
 * NOTE: we have the opportunity to use different
 * parameters based on the input flavor (which must be a pseudoflavor)
 *
 * gss_create - build an RPCSEC_GSS authenticator for @clnt.
 *
 * Setup order:
 *   1. pin this module;
 *   2. resolve @flavor to a mechanism and a service;
 *   3. create both upcall pipes, then their rpc_pipefs dentries (when
 *      rpc_pipefs is mounted);
 *   4. set up the credential cache.
 * Returns the embedded rpc_auth or an ERR_PTR.  Each failure unwinds the
 * earlier steps in reverse order.
 */
static struct rpc_auth *
gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
{
	struct gss_auth *gss_auth;
	struct rpc_auth * auth;
	int err = -ENOMEM; /* XXX? */

	dprintk("RPC:       creating GSS authenticator for client %p\n", clnt);

	if (!try_module_get(THIS_MODULE))
		return ERR_PTR(err);
	if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
		goto out_dec;
	gss_auth->client = clnt;
	err = -EINVAL;
	gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
	if (!gss_auth->mech) {
		printk(KERN_WARNING "%s: Pseudoflavor %d not found!\n",
				__func__, flavor);
		goto err_free;
	}
	gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor);
	if (gss_auth->service == 0)
		goto err_put_mech;
	auth = &gss_auth->rpc_auth;
	/* Slack sizes are stored divided by four (byte counts >> 2). */
	auth->au_cslack = GSS_CRED_SLACK >> 2;
	auth->au_rslack = GSS_VERF_SLACK >> 2;
	auth->au_ops = &authgss_ops;
	auth->au_flavor = flavor;
	atomic_set(&auth->au_count, 1);
	/* This initial reference is dropped by gss_destroy(). */
	kref_init(&gss_auth->kref);

	/*
	 * Note: if we created the old pipe first, then someone who
	 * examined the directory at the right moment might conclude
	 * that we supported only the old pipe.  So we instead create
	 * the new pipe first.
	 */
	gss_auth->pipe[1] = rpc_mkpipe_data(&gss_upcall_ops_v1,
					    RPC_PIPE_WAIT_FOR_OPEN);
	if (IS_ERR(gss_auth->pipe[1])) {
		err = PTR_ERR(gss_auth->pipe[1]);
		goto err_put_mech;
	}

	gss_auth->pipe[0] = rpc_mkpipe_data(&gss_upcall_ops_v0,
					    RPC_PIPE_WAIT_FOR_OPEN);
	if (IS_ERR(gss_auth->pipe[0])) {
		err = PTR_ERR(gss_auth->pipe[0]);
		goto err_destroy_pipe_1;
	}
	err = gss_pipes_dentries_create_net(clnt, auth);
	if (err)
		goto err_destroy_pipe_0;
	err = rpcauth_init_credcache(auth);
	if (err)
		goto err_unlink_pipes;

	return auth;
err_unlink_pipes:
	gss_pipes_dentries_destroy_net(clnt, auth);
err_destroy_pipe_0:
	rpc_destroy_pipe_data(gss_auth->pipe[0]);
err_destroy_pipe_1:
	rpc_destroy_pipe_data(gss_auth->pipe[1]);
err_put_mech:
	gss_mech_put(gss_auth->mech);
err_free:
	kfree(gss_auth);
out_dec:
	module_put(THIS_MODULE);
	return ERR_PTR(err);
}
915
916static void
917gss_free(struct gss_auth *gss_auth)
918{
919	gss_pipes_dentries_destroy_net(gss_auth->client, &gss_auth->rpc_auth);
920	rpc_destroy_pipe_data(gss_auth->pipe[0]);
921	rpc_destroy_pipe_data(gss_auth->pipe[1]);
922	gss_mech_put(gss_auth->mech);
923
924	kfree(gss_auth);
925	module_put(THIS_MODULE);
926}
927
928static void
929gss_free_callback(struct kref *kref)
930{
931	struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref);
932
933	gss_free(gss_auth);
934}
935
936static void
937gss_destroy(struct rpc_auth *auth)
938{
939	struct gss_auth *gss_auth;
940
941	dprintk("RPC:       destroying GSS authenticator %p flavor %d\n",
942			auth, auth->au_flavor);
943
944	rpcauth_destroy_credcache(auth);
945
946	gss_auth = container_of(auth, struct gss_auth, rpc_auth);
947	kref_put(&gss_auth->kref, gss_free_callback);
948}
949
/*
 * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call
 * to the server with the GSS control procedure field set to
 * RPC_GSS_PROC_DESTROY. This should normally cause the server to release
 * all RPCSEC_GSS state associated with that context.
 *
 * Returns 0, with nothing sent, if the cred has no context or is not
 * UPTODATE; the caller then frees the cred itself.  Returns 1 after the
 * cred has been switched to gss_nullops and the NULL call issued; the cred
 * is then released through that call or the put_rpccred() below.
 */
static int
gss_destroying_context(struct rpc_cred *cred)
{
	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
	struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
	struct rpc_task *task;

	if (gss_cred->gc_ctx == NULL ||
	    test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
		return 0;

	/* gss_marshal() puts gc_proc into the credential it sends. */
	gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY;
	/* NOTE(review): gss_nullops is presumably what keeps the final put
	 * from coming back here; it is defined outside the visible source. */
	cred->cr_ops = &gss_nullops;

	/* Take a reference to ensure the cred will be destroyed either
	 * by the RPC call or by the put_rpccred() below */
	get_rpccred(cred);

	task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT);
	if (!IS_ERR(task))
		rpc_put_task(task);

	put_rpccred(cred);
	return 1;
}
981
982/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure
983 * to create a new cred or context, so they check that things have been
984 * allocated before freeing them. */
985static void
986gss_do_free_ctx(struct gss_cl_ctx *ctx)
987{
988	dprintk("RPC:       %s\n", __func__);
989
990	gss_delete_sec_context(&ctx->gc_gss_ctx);
991	kfree(ctx->gc_wire_ctx.data);
992	kfree(ctx);
993}
994
995static void
996gss_free_ctx_callback(struct rcu_head *head)
997{
998	struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu);
999	gss_do_free_ctx(ctx);
1000}
1001
1002static void
1003gss_free_ctx(struct gss_cl_ctx *ctx)
1004{
1005	call_rcu(&ctx->gc_rcu, gss_free_ctx_callback);
1006}
1007
/* Free the memory of a gss cred (reached after an RCU grace period). */
static void
gss_free_cred(struct gss_cred *gss_cred)
{
	struct gss_cred *victim = gss_cred;

	dprintk("RPC:       %s cred=%p\n", __func__, victim);
	kfree(victim);
}
1014
1015static void
1016gss_free_cred_callback(struct rcu_head *head)
1017{
1018	struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu);
1019	gss_free_cred(gss_cred);
1020}
1021
/*
 * gss_destroy_nullcred - free a gss cred without contacting the server.
 *
 * Steps:
 *   1. detach the context;
 *   2. free the cred memory after an RCU grace period, because lockless
 *      readers such as gss_cred_get_ctx() may still hold a pointer;
 *   3. drop the cred's context reference;
 *   4. drop the gss_auth reference taken in gss_create_cred().
 */
static void
gss_destroy_nullcred(struct rpc_cred *cred)
{
	struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
	struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
	struct gss_cl_ctx *ctx = gss_cred->gc_ctx;

	RCU_INIT_POINTER(gss_cred->gc_ctx, NULL);
	call_rcu(&cred->cr_rcu, gss_free_cred_callback);
	if (ctx)
		gss_put_ctx(ctx);
	kref_put(&gss_auth->kref, gss_free_callback);
}
1035
/*
 * gss_destroy_cred - destroy a gss cred.
 *
 * If gss_destroying_context() sent a context-destroy call, the cred now
 * belongs to that call.  Otherwise it is freed right here.
 */
static void
gss_destroy_cred(struct rpc_cred *cred)
{
	if (!gss_destroying_context(cred))
		gss_destroy_nullcred(cred);
}
1044
/*
 * Lookup RPCSEC_GSS cred for the current process.
 *
 * Delegates entirely to the generic credential cache.
 */
static struct rpc_cred *
gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
{
	struct rpc_cred *cred = rpcauth_lookup_credcache(auth, acred, flags);

	return cred;
}
1053
1054static struct rpc_cred *
1055gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
1056{
1057	struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1058	struct gss_cred	*cred = NULL;
1059	int err = -ENOMEM;
1060
1061	dprintk("RPC:       %s for uid %d, flavor %d\n",
1062		__func__, acred->uid, auth->au_flavor);
1063
1064	if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS)))
1065		goto out_err;
1066
1067	rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops);
1068	/*
1069	 * Note: in order to force a call to call_refresh(), we deliberately
1070	 * fail to flag the credential as RPCAUTH_CRED_UPTODATE.
1071	 */
1072	cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW;
1073	cred->gc_service = gss_auth->service;
1074	cred->gc_principal = NULL;
1075	if (acred->machine_cred)
1076		cred->gc_principal = acred->principal;
1077	kref_get(&gss_auth->kref);
1078	return &cred->gc_base;
1079
1080out_err:
1081	dprintk("RPC:       %s failed with error %d\n", __func__, err);
1082	return ERR_PTR(err);
1083}
1084
1085static int
1086gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred)
1087{
1088	struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1089	struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base);
1090	int err;
1091
1092	do {
1093		err = gss_create_upcall(gss_auth, gss_cred);
1094	} while (err == -EAGAIN);
1095	return err;
1096}
1097
1098static int
1099gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags)
1100{
1101	struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
1102
1103	if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
1104		goto out;
1105	/* Don't match with creds that have expired. */
1106	if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
1107		return 0;
1108	if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags))
1109		return 0;
1110out:
1111	if (acred->principal != NULL) {
1112		if (gss_cred->gc_principal == NULL)
1113			return 0;
1114		return strcmp(acred->principal, gss_cred->gc_principal) == 0;
1115	}
1116	if (gss_cred->gc_principal != NULL)
1117		return 0;
1118	return rc->cr_uid == acred->uid;
1119}
1120
1121/*
1122* Marshal credentials.
1123* Maybe we should keep a cached credential for performance reasons.
1124*/
1125static __be32 *
1126gss_marshal(struct rpc_task *task, __be32 *p)
1127{
1128	struct rpc_rqst *req = task->tk_rqstp;
1129	struct rpc_cred *cred = req->rq_cred;
1130	struct gss_cred	*gss_cred = container_of(cred, struct gss_cred,
1131						 gc_base);
1132	struct gss_cl_ctx	*ctx = gss_cred_get_ctx(cred);
1133	__be32		*cred_len;
1134	u32             maj_stat = 0;
1135	struct xdr_netobj mic;
1136	struct kvec	iov;
1137	struct xdr_buf	verf_buf;
1138
1139	dprintk("RPC: %5u %s\n", task->tk_pid, __func__);
1140
1141	*p++ = htonl(RPC_AUTH_GSS);
1142	cred_len = p++;
1143
1144	spin_lock(&ctx->gc_seq_lock);
1145	req->rq_seqno = ctx->gc_seq++;
1146	spin_unlock(&ctx->gc_seq_lock);
1147
1148	*p++ = htonl((u32) RPC_GSS_VERSION);
1149	*p++ = htonl((u32) ctx->gc_proc);
1150	*p++ = htonl((u32) req->rq_seqno);
1151	*p++ = htonl((u32) gss_cred->gc_service);
1152	p = xdr_encode_netobj(p, &ctx->gc_wire_ctx);
1153	*cred_len = htonl((p - (cred_len + 1)) << 2);
1154
1155	/* We compute the checksum for the verifier over the xdr-encoded bytes
1156	 * starting with the xid and ending at the end of the credential: */
1157	iov.iov_base = xprt_skip_transport_header(req->rq_xprt,
1158					req->rq_snd_buf.head[0].iov_base);
1159	iov.iov_len = (u8 *)p - (u8 *)iov.iov_base;
1160	xdr_buf_from_iov(&iov, &verf_buf);
1161
1162	/* set verifier flavor*/
1163	*p++ = htonl(RPC_AUTH_GSS);
1164
1165	mic.data = (u8 *)(p + 1);
1166	maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1167	if (maj_stat == GSS_S_CONTEXT_EXPIRED) {
1168		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1169	} else if (maj_stat != 0) {
1170		printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat);
1171		goto out_put_ctx;
1172	}
1173	p = xdr_encode_opaque(p, NULL, mic.len);
1174	gss_put_ctx(ctx);
1175	return p;
1176out_put_ctx:
1177	gss_put_ctx(ctx);
1178	return NULL;
1179}
1180
1181static int gss_renew_cred(struct rpc_task *task)
1182{
1183	struct rpc_cred *oldcred = task->tk_rqstp->rq_cred;
1184	struct gss_cred *gss_cred = container_of(oldcred,
1185						 struct gss_cred,
1186						 gc_base);
1187	struct rpc_auth *auth = oldcred->cr_auth;
1188	struct auth_cred acred = {
1189		.uid = oldcred->cr_uid,
1190		.principal = gss_cred->gc_principal,
1191		.machine_cred = (gss_cred->gc_principal != NULL ? 1 : 0),
1192	};
1193	struct rpc_cred *new;
1194
1195	new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW);
1196	if (IS_ERR(new))
1197		return PTR_ERR(new);
1198	task->tk_rqstp->rq_cred = new;
1199	put_rpccred(oldcred);
1200	return 0;
1201}
1202
1203static int gss_cred_is_negative_entry(struct rpc_cred *cred)
1204{
1205	if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) {
1206		unsigned long now = jiffies;
1207		unsigned long begin, expire;
1208		struct gss_cred *gss_cred;
1209
1210		gss_cred = container_of(cred, struct gss_cred, gc_base);
1211		begin = gss_cred->gc_upcall_timestamp;
1212		expire = begin + gss_expired_cred_retry_delay * HZ;
1213
1214		if (time_in_range_open(now, begin, expire))
1215			return 1;
1216	}
1217	return 0;
1218}
1219
1220/*
1221* Refresh credentials. XXX - finish
1222*/
1223static int
1224gss_refresh(struct rpc_task *task)
1225{
1226	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1227	int ret = 0;
1228
1229	if (gss_cred_is_negative_entry(cred))
1230		return -EKEYEXPIRED;
1231
1232	if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
1233			!test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) {
1234		ret = gss_renew_cred(task);
1235		if (ret < 0)
1236			goto out;
1237		cred = task->tk_rqstp->rq_cred;
1238	}
1239
1240	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
1241		ret = gss_refresh_upcall(task);
1242out:
1243	return ret;
1244}
1245
1246/* Dummy refresh routine: used only when destroying the context */
1247static int
1248gss_refresh_null(struct rpc_task *task)
1249{
1250	return -EACCES;
1251}
1252
1253static __be32 *
1254gss_validate(struct rpc_task *task, __be32 *p)
1255{
1256	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1257	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1258	__be32		seq;
1259	struct kvec	iov;
1260	struct xdr_buf	verf_buf;
1261	struct xdr_netobj mic;
1262	u32		flav,len;
1263	u32		maj_stat;
1264
1265	dprintk("RPC: %5u %s\n", task->tk_pid, __func__);
1266
1267	flav = ntohl(*p++);
1268	if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE)
1269		goto out_bad;
1270	if (flav != RPC_AUTH_GSS)
1271		goto out_bad;
1272	seq = htonl(task->tk_rqstp->rq_seqno);
1273	iov.iov_base = &seq;
1274	iov.iov_len = sizeof(seq);
1275	xdr_buf_from_iov(&iov, &verf_buf);
1276	mic.data = (u8 *)p;
1277	mic.len = len;
1278
1279	maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1280	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1281		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1282	if (maj_stat) {
1283		dprintk("RPC: %5u %s: gss_verify_mic returned error 0x%08x\n",
1284			task->tk_pid, __func__, maj_stat);
1285		goto out_bad;
1286	}
1287	/* We leave it to unwrap to calculate au_rslack. For now we just
1288	 * calculate the length of the verifier: */
1289	cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2;
1290	gss_put_ctx(ctx);
1291	dprintk("RPC: %5u %s: gss_verify_mic succeeded.\n",
1292			task->tk_pid, __func__);
1293	return p + XDR_QUADLEN(len);
1294out_bad:
1295	gss_put_ctx(ctx);
1296	dprintk("RPC: %5u %s failed.\n", task->tk_pid, __func__);
1297	return NULL;
1298}
1299
1300static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
1301				__be32 *p, void *obj)
1302{
1303	struct xdr_stream xdr;
1304
1305	xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p);
1306	encode(rqstp, &xdr, obj);
1307}
1308
/*
 * Wrap the call body for RPC_GSS_SVC_INTEGRITY.
 *
 * Layout written at @p:
 *  - a length word, filled in afterwards;
 *  - the request's sequence number;
 *  - the encoded arguments;
 *  - finally, the MIC over [sequence number .. end of arguments] as an
 *    XDR opaque.
 *
 * Returns 0 on success.  Returns -EIO if the integrity region cannot be
 * described, or if gss_get_mic() fails with anything other than
 * GSS_S_CONTEXT_EXPIRED (which only clears RPCAUTH_CRED_UPTODATE).
 */
static inline int
gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
		   kxdreproc_t encode, struct rpc_rqst *rqstp,
		   __be32 *p, void *obj)
{
	struct xdr_buf	*snd_buf = &rqstp->rq_snd_buf;
	struct xdr_buf	integ_buf;
	__be32          *integ_len = NULL;
	struct xdr_netobj mic;
	u32		offset;
	__be32		*q;
	struct kvec	*iov;
	u32             maj_stat = 0;
	int		status = -EIO;

	/* Reserve the length word; it is back-filled once the body is encoded. */
	integ_len = p++;
	/* Byte offset from head[0] where the integrity-protected region starts. */
	offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
	*p++ = htonl(rqstp->rq_seqno);

	gss_wrap_req_encode(encode, rqstp, p, obj);

	/* integ_buf describes seqno + arguments: everything after offset. */
	if (xdr_buf_subsegment(snd_buf, &integ_buf,
				offset, snd_buf->len - offset))
		return status;
	*integ_len = htonl(integ_buf.len);

	/* guess whether we're in the head or the tail: */
	if (snd_buf->page_len || snd_buf->tail[0].iov_len)
		iov = snd_buf->tail;
	else
		iov = snd_buf->head;
	/* The MIC is produced in place past the current end of iov,
	 * leaving one word in front of it for the opaque length. */
	p = iov->iov_base + iov->iov_len;
	mic.data = (u8 *)(p + 1);

	maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
	status = -EIO; /* XXX? */
	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
	else if (maj_stat)
		return status;
	/* With NULL data, xdr_encode_opaque() is expected to write only the
	 * length word (and XDR padding) around the MIC already at p + 1. */
	q = xdr_encode_opaque(p, NULL, mic.len);

	/* Account for the length word, MIC and padding just appended. */
	offset = (u8 *)q - (u8 *)p;
	iov->iov_len += offset;
	snd_buf->len += offset;
	return 0;
}
1356
1357static void
1358priv_release_snd_buf(struct rpc_rqst *rqstp)
1359{
1360	int i;
1361
1362	for (i=0; i < rqstp->rq_enc_pages_num; i++)
1363		__free_page(rqstp->rq_enc_pages[i]);
1364	kfree(rqstp->rq_enc_pages);
1365}
1366
1367static int
1368alloc_enc_pages(struct rpc_rqst *rqstp)
1369{
1370	struct xdr_buf *snd_buf = &rqstp->rq_snd_buf;
1371	int first, last, i;
1372
1373	if (snd_buf->page_len == 0) {
1374		rqstp->rq_enc_pages_num = 0;
1375		return 0;
1376	}
1377
1378	first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1379	last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_CACHE_SHIFT;
1380	rqstp->rq_enc_pages_num = last - first + 1 + 1;
1381	rqstp->rq_enc_pages
1382		= kmalloc(rqstp->rq_enc_pages_num * sizeof(struct page *),
1383				GFP_NOFS);
1384	if (!rqstp->rq_enc_pages)
1385		goto out;
1386	for (i=0; i < rqstp->rq_enc_pages_num; i++) {
1387		rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS);
1388		if (rqstp->rq_enc_pages[i] == NULL)
1389			goto out_free;
1390	}
1391	rqstp->rq_release_snd_buf = priv_release_snd_buf;
1392	return 0;
1393out_free:
1394	rqstp->rq_enc_pages_num = i;
1395	priv_release_snd_buf(rqstp);
1396out:
1397	return -EAGAIN;
1398}
1399
/*
 * Wrap the call body for RPC_GSS_SVC_PRIVACY.
 *
 * Steps:
 *  - write a length word (back-filled later) and the sequence number;
 *  - encode the arguments;
 *  - have gss_wrap() encrypt everything from the sequence number onward.
 *    The original page data is handed to gss_wrap() as @inpages, while
 *    snd_buf->pages is redirected to the bounce pages from
 *    alloc_enc_pages(); the tail is first copied into the last bounce
 *    page;
 *  - pad the wrapped token with zeroes to a 4-byte boundary.
 *
 * Returns 0, the error from alloc_enc_pages(), or -EIO if gss_wrap()
 * fails with anything other than GSS_S_CONTEXT_EXPIRED.
 */
static inline int
gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
		  kxdreproc_t encode, struct rpc_rqst *rqstp,
		  __be32 *p, void *obj)
{
	struct xdr_buf	*snd_buf = &rqstp->rq_snd_buf;
	u32		offset;
	u32             maj_stat;
	int		status;
	__be32		*opaque_len;
	struct page	**inpages;
	int		first;
	int		pad;
	struct kvec	*iov;
	char		*tmp;

	/* Reserve the opaque length word; back-filled after wrapping. */
	opaque_len = p++;
	/* Byte offset from head[0] where the region handed to gss_wrap() starts. */
	offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
	*p++ = htonl(rqstp->rq_seqno);

	gss_wrap_req_encode(encode, rqstp, p, obj);

	status = alloc_enc_pages(rqstp);
	if (status)
		return status;
	/* Plaintext pages go to gss_wrap() as inpages; the send buffer now
	 * points at the bounce pages, rebased to start at page 'first'. */
	first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
	inpages = snd_buf->pages + first;
	snd_buf->pages = rqstp->rq_enc_pages;
	snd_buf->page_base -= first << PAGE_CACHE_SHIFT;
	/*
	 * Give the tail its own page, in case we need extra space in the
	 * head when wrapping:
	 *
	 * call_allocate() allocates twice the slack space required
	 * by the authentication flavor to rq_callsize.
	 * For GSS, slack is GSS_CRED_SLACK.
	 *
	 * NOTE(review): alloc_enc_pages() allocates nothing when page_len is
	 * 0, so a non-empty tail without page data would index
	 * rq_enc_pages[-1] here.  Presumably encoders never produce a tail
	 * without pages -- confirm.
	 */
	if (snd_buf->page_len || snd_buf->tail[0].iov_len) {
		tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]);
		memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len);
		snd_buf->tail[0].iov_base = tmp;
	}
	maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages);
	/* slack space should prevent this ever happening: */
	BUG_ON(snd_buf->len > snd_buf->buflen);
	status = -EIO;
	/* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
	 * done anyway, so it's safe to put the request on the wire: */
	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
	else if (maj_stat)
		return status;

	*opaque_len = htonl(snd_buf->len - offset);
	/* guess whether we're in the head or the tail: */
	if (snd_buf->page_len || snd_buf->tail[0].iov_len)
		iov = snd_buf->tail;
	else
		iov = snd_buf->head;
	p = iov->iov_base + iov->iov_len;
	/* XDR padding: pad = (4 - (len % 4)) % 4 for the wrapped token. */
	pad = 3 - ((snd_buf->len - offset - 1) & 3);
	memset(p, 0, pad);
	iov->iov_len += pad;
	snd_buf->len += pad;

	return 0;
}
1467
1468static int
1469gss_wrap_req(struct rpc_task *task,
1470	     kxdreproc_t encode, void *rqstp, __be32 *p, void *obj)
1471{
1472	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1473	struct gss_cred	*gss_cred = container_of(cred, struct gss_cred,
1474			gc_base);
1475	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1476	int             status = -EIO;
1477
1478	dprintk("RPC: %5u %s\n", task->tk_pid, __func__);
1479	if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
1480		/* The spec seems a little ambiguous here, but I think that not
1481		 * wrapping context destruction requests makes the most sense.
1482		 */
1483		gss_wrap_req_encode(encode, rqstp, p, obj);
1484		status = 0;
1485		goto out;
1486	}
1487	switch (gss_cred->gc_service) {
1488	case RPC_GSS_SVC_NONE:
1489		gss_wrap_req_encode(encode, rqstp, p, obj);
1490		status = 0;
1491		break;
1492	case RPC_GSS_SVC_INTEGRITY:
1493		status = gss_wrap_req_integ(cred, ctx, encode, rqstp, p, obj);
1494		break;
1495	case RPC_GSS_SVC_PRIVACY:
1496		status = gss_wrap_req_priv(cred, ctx, encode, rqstp, p, obj);
1497		break;
1498	}
1499out:
1500	gss_put_ctx(ctx);
1501	dprintk("RPC: %5u %s returning %d\n", task->tk_pid, __func__, status);
1502	return status;
1503}
1504
1505static inline int
1506gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1507		struct rpc_rqst *rqstp, __be32 **p)
1508{
1509	struct xdr_buf	*rcv_buf = &rqstp->rq_rcv_buf;
1510	struct xdr_buf integ_buf;
1511	struct xdr_netobj mic;
1512	u32 data_offset, mic_offset;
1513	u32 integ_len;
1514	u32 maj_stat;
1515	int status = -EIO;
1516
1517	integ_len = ntohl(*(*p)++);
1518	if (integ_len & 3)
1519		return status;
1520	data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1521	mic_offset = integ_len + data_offset;
1522	if (mic_offset > rcv_buf->len)
1523		return status;
1524	if (ntohl(*(*p)++) != rqstp->rq_seqno)
1525		return status;
1526
1527	if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
1528				mic_offset - data_offset))
1529		return status;
1530
1531	if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
1532		return status;
1533
1534	maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1535	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1536		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1537	if (maj_stat != GSS_S_COMPLETE)
1538		return status;
1539	return 0;
1540}
1541
1542static inline int
1543gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1544		struct rpc_rqst *rqstp, __be32 **p)
1545{
1546	struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1547	u32 offset;
1548	u32 opaque_len;
1549	u32 maj_stat;
1550	int status = -EIO;
1551
1552	opaque_len = ntohl(*(*p)++);
1553	offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1554	if (offset + opaque_len > rcv_buf->len)
1555		return status;
1556	/* remove padding: */
1557	rcv_buf->len = offset + opaque_len;
1558
1559	maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf);
1560	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1561		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1562	if (maj_stat != GSS_S_COMPLETE)
1563		return status;
1564	if (ntohl(*(*p)++) != rqstp->rq_seqno)
1565		return status;
1566
1567	return 0;
1568}
1569
1570static int
1571gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
1572		      __be32 *p, void *obj)
1573{
1574	struct xdr_stream xdr;
1575
1576	xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p);
1577	return decode(rqstp, &xdr, obj);
1578}
1579
/*
 * rpc_credops->crunwrap_resp: strip RPCSEC_GSS protection from a reply
 * and run @decode on the result.
 *
 * Replies to non-DATA procedures are decoded as-is.  For DATA replies,
 * the integrity/privacy helpers verify or unwrap the body and advance @p
 * to the results.  If either helper fails, decoding is skipped and its
 * error is returned.  An unrecognised gc_service is handled like
 * RPC_GSS_SVC_NONE.  au_rslack is then updated before decoding.
 */
static int
gss_unwrap_resp(struct rpc_task *task,
		kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj)
{
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
	struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
			gc_base);
	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
	/* Snapshots so the space consumed by unwrapping can be measured. */
	__be32		*savedp = p;
	struct kvec	*head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head;
	int		savedlen = head->iov_len;
	int             status = -EIO;

	if (ctx->gc_proc != RPC_GSS_PROC_DATA)
		goto out_decode;
	switch (gss_cred->gc_service) {
	case RPC_GSS_SVC_NONE:
		break;
	case RPC_GSS_SVC_INTEGRITY:
		status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p);
		if (status)
			goto out;
		break;
	case RPC_GSS_SVC_PRIVACY:
		status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p);
		if (status)
			goto out;
		break;
	}
	/* take into account extra slack for integrity and privacy cases: */
	/*
	 * NOTE(review): au_verfsize (set by gss_validate() as
	 * XDR_QUADLEN(len) + 2) and (p - savedp) are counts of 4-byte XDR
	 * words.  (savedlen - head->iov_len) is a byte count.  Mixing units
	 * looks like it over-estimates the slack -- confirm the intended
	 * unit before changing.
	 */
	cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
						+ (savedlen - head->iov_len);
out_decode:
	status = gss_unwrap_req_decode(decode, rqstp, p, obj);
out:
	gss_put_ctx(ctx);
	dprintk("RPC: %5u %s returning %d\n",
		task->tk_pid, __func__, status);
	return status;
}
1620
1621static const struct rpc_authops authgss_ops = {
1622	.owner		= THIS_MODULE,
1623	.au_flavor	= RPC_AUTH_GSS,
1624	.au_name	= "RPCSEC_GSS",
1625	.create		= gss_create,
1626	.destroy	= gss_destroy,
1627	.lookup_cred	= gss_lookup_cred,
1628	.crcreate	= gss_create_cred,
1629	.pipes_create	= gss_pipes_dentries_create,
1630	.pipes_destroy	= gss_pipes_dentries_destroy,
1631	.list_pseudoflavors = gss_mech_list_pseudoflavors,
1632};
1633
1634static const struct rpc_credops gss_credops = {
1635	.cr_name	= "AUTH_GSS",
1636	.crdestroy	= gss_destroy_cred,
1637	.cr_init	= gss_cred_init,
1638	.crbind		= rpcauth_generic_bind_cred,
1639	.crmatch	= gss_match,
1640	.crmarshal	= gss_marshal,
1641	.crrefresh	= gss_refresh,
1642	.crvalidate	= gss_validate,
1643	.crwrap_req	= gss_wrap_req,
1644	.crunwrap_resp	= gss_unwrap_resp,
1645};
1646
1647static const struct rpc_credops gss_nullops = {
1648	.cr_name	= "AUTH_GSS",
1649	.crdestroy	= gss_destroy_nullcred,
1650	.crbind		= rpcauth_generic_bind_cred,
1651	.crmatch	= gss_match,
1652	.crmarshal	= gss_marshal,
1653	.crrefresh	= gss_refresh_null,
1654	.crvalidate	= gss_validate,
1655	.crwrap_req	= gss_wrap_req,
1656	.crunwrap_resp	= gss_unwrap_resp,
1657};
1658
1659static const struct rpc_pipe_ops gss_upcall_ops_v0 = {
1660	.upcall		= rpc_pipe_generic_upcall,
1661	.downcall	= gss_pipe_downcall,
1662	.destroy_msg	= gss_pipe_destroy_msg,
1663	.open_pipe	= gss_pipe_open_v0,
1664	.release_pipe	= gss_pipe_release,
1665};
1666
1667static const struct rpc_pipe_ops gss_upcall_ops_v1 = {
1668	.upcall		= rpc_pipe_generic_upcall,
1669	.downcall	= gss_pipe_downcall,
1670	.destroy_msg	= gss_pipe_destroy_msg,
1671	.open_pipe	= gss_pipe_open_v1,
1672	.release_pipe	= gss_pipe_release,
1673};
1674
1675static __net_init int rpcsec_gss_init_net(struct net *net)
1676{
1677	return gss_svc_init_net(net);
1678}
1679
1680static __net_exit void rpcsec_gss_exit_net(struct net *net)
1681{
1682	gss_svc_shutdown_net(net);
1683}
1684
1685static struct pernet_operations rpcsec_gss_net_ops = {
1686	.init = rpcsec_gss_init_net,
1687	.exit = rpcsec_gss_exit_net,
1688};
1689
1690/*
1691 * Initialize RPCSEC_GSS module
1692 */
1693static int __init init_rpcsec_gss(void)
1694{
1695	int err = 0;
1696
1697	err = rpcauth_register(&authgss_ops);
1698	if (err)
1699		goto out;
1700	err = gss_svc_init();
1701	if (err)
1702		goto out_unregister;
1703	err = register_pernet_subsys(&rpcsec_gss_net_ops);
1704	if (err)
1705		goto out_svc_exit;
1706	rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version");
1707	return 0;
1708out_svc_exit:
1709	gss_svc_shutdown();
1710out_unregister:
1711	rpcauth_unregister(&authgss_ops);
1712out:
1713	return err;
1714}
1715
/*
 * Module unload: undo init_rpcsec_gss() in reverse order (per-netns ops,
 * then server-side GSS, then the client auth flavor).
 */
static void __exit exit_rpcsec_gss(void)
{
	unregister_pernet_subsys(&rpcsec_gss_net_ops);
	gss_svc_shutdown();
	rpcauth_unregister(&authgss_ops);
	/* Presumably ensures no RCU callback queued by this module runs
	 * after the module is gone. */
	rcu_barrier(); /* Wait for completion of call_rcu()'s */
}
1723
MODULE_LICENSE("GPL");
/*
 * expired_cred_retry_delay (mode 0644), backed by
 * gss_expired_cred_retry_delay.  It is in seconds: it is scaled by HZ in
 * gss_cred_is_negative_entry().  For that long after its upcall, a
 * negative credential entry makes gss_refresh() fail with -EKEYEXPIRED.
 */
module_param_named(expired_cred_retry_delay,
		   gss_expired_cred_retry_delay,
		   uint, 0644);
MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until "
		"the RPC engine retries an expired credential");

module_init(init_rpcsec_gss)
module_exit(exit_rpcsec_gss)
1733