1/* 2 * linux/net/sunrpc/auth.c 3 * 4 * Generic RPC client authentication API. 5 * 6 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> 7 */ 8 9#include <linux/types.h> 10#include <linux/sched.h> 11#include <linux/module.h> 12#include <linux/slab.h> 13#include <linux/errno.h> 14#include <linux/hash.h> 15#include <linux/sunrpc/clnt.h> 16#include <linux/sunrpc/gss_api.h> 17#include <linux/spinlock.h> 18 19#ifdef RPC_DEBUG 20# define RPCDBG_FACILITY RPCDBG_AUTH 21#endif 22 23#define RPC_CREDCACHE_DEFAULT_HASHBITS (4) 24struct rpc_cred_cache { 25 struct hlist_head *hashtable; 26 unsigned int hashbits; 27 spinlock_t lock; 28}; 29 30static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS; 31 32static DEFINE_SPINLOCK(rpc_authflavor_lock); 33static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = { 34 &authnull_ops, /* AUTH_NULL */ 35 &authunix_ops, /* AUTH_UNIX */ 36 NULL, /* others can be loadable modules */ 37}; 38 39static LIST_HEAD(cred_unused); 40static unsigned long number_cred_unused; 41 42#define MAX_HASHTABLE_BITS (14) 43static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp) 44{ 45 unsigned long num; 46 unsigned int nbits; 47 int ret; 48 49 if (!val) 50 goto out_inval; 51 ret = strict_strtoul(val, 0, &num); 52 if (ret == -EINVAL) 53 goto out_inval; 54 nbits = fls(num); 55 if (num > (1U << nbits)) 56 nbits++; 57 if (nbits > MAX_HASHTABLE_BITS || nbits < 2) 58 goto out_inval; 59 *(unsigned int *)kp->arg = nbits; 60 return 0; 61out_inval: 62 return -EINVAL; 63} 64 65static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp) 66{ 67 unsigned int nbits; 68 69 nbits = *(unsigned int *)kp->arg; 70 return sprintf(buffer, "%u", 1U << nbits); 71} 72 73#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int); 74 75static struct kernel_param_ops param_ops_hashtbl_sz = { 76 .set = param_set_hashtbl_sz, 77 .get = param_get_hashtbl_sz, 78}; 79 80module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644); 81MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size"); 82 83static u32 84pseudoflavor_to_flavor(u32 flavor) { 85 if (flavor > RPC_AUTH_MAXFLAVOR) 86 return RPC_AUTH_GSS; 87 return flavor; 88} 89 90int 91rpcauth_register(const struct rpc_authops *ops) 92{ 93 rpc_authflavor_t flavor; 94 int ret = -EPERM; 95 96 if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) 97 return -EINVAL; 98 spin_lock(&rpc_authflavor_lock); 99 if (auth_flavors[flavor] == NULL) { 100 auth_flavors[flavor] = ops; 101 ret = 0; 102 } 103 spin_unlock(&rpc_authflavor_lock); 104 return ret; 105} 106EXPORT_SYMBOL_GPL(rpcauth_register); 107 108int 109rpcauth_unregister(const struct rpc_authops *ops) 110{ 111 rpc_authflavor_t flavor; 112 int ret = -EPERM; 113 114 if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) 115 return -EINVAL; 116 spin_lock(&rpc_authflavor_lock); 117 if (auth_flavors[flavor] == ops) { 118 auth_flavors[flavor] = NULL; 119 ret = 0; 120 } 121 spin_unlock(&rpc_authflavor_lock); 122 return ret; 123} 124EXPORT_SYMBOL_GPL(rpcauth_unregister); 125 126/** 127 * rpcauth_get_pseudoflavor - check if security flavor is supported 128 * @flavor: a security flavor 129 * @info: a GSS mech OID, quality of protection, and service value 130 * 131 * Verifies that an appropriate kernel module is available or already loaded. 132 * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is 133 * not supported locally. 134 */ 135rpc_authflavor_t 136rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info) 137{ 138 const struct rpc_authops *ops; 139 rpc_authflavor_t pseudoflavor; 140 141 ops = auth_flavors[flavor]; 142 if (ops == NULL) 143 request_module("rpc-auth-%u", flavor); 144 spin_lock(&rpc_authflavor_lock); 145 ops = auth_flavors[flavor]; 146 if (ops == NULL || !try_module_get(ops->owner)) { 147 spin_unlock(&rpc_authflavor_lock); 148 return RPC_AUTH_MAXFLAVOR; 149 } 150 spin_unlock(&rpc_authflavor_lock); 151 152 pseudoflavor = flavor; 153 if (ops->info2flavor != NULL) 154 pseudoflavor = ops->info2flavor(info); 155 156 module_put(ops->owner); 157 return pseudoflavor; 158} 159EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor); 160 161/** 162 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor 163 * @pseudoflavor: GSS pseudoflavor to match 164 * @info: rpcsec_gss_info structure to fill in 165 * 166 * Returns zero and fills in "info" if pseudoflavor matches a 167 * supported mechanism. 168 */ 169int 170rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info) 171{ 172 rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor); 173 const struct rpc_authops *ops; 174 int result; 175 176 if (flavor >= RPC_AUTH_MAXFLAVOR) 177 return -EINVAL; 178 179 ops = auth_flavors[flavor]; 180 if (ops == NULL) 181 request_module("rpc-auth-%u", flavor); 182 spin_lock(&rpc_authflavor_lock); 183 ops = auth_flavors[flavor]; 184 if (ops == NULL || !try_module_get(ops->owner)) { 185 spin_unlock(&rpc_authflavor_lock); 186 return -ENOENT; 187 } 188 spin_unlock(&rpc_authflavor_lock); 189 190 result = -ENOENT; 191 if (ops->flavor2info != NULL) 192 result = ops->flavor2info(pseudoflavor, info); 193 194 module_put(ops->owner); 195 return result; 196} 197EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo); 198 199/** 200 * rpcauth_list_flavors - discover registered flavors and pseudoflavors 201 * @array: array to fill in 202 * @size: size of "array" 203 * 204 * Returns the number of array items filled in, or a negative errno. 205 * 206 * The returned array is not sorted by any policy. Callers should not 207 * rely on the order of the items in the returned array. 208 */ 209int 210rpcauth_list_flavors(rpc_authflavor_t *array, int size) 211{ 212 rpc_authflavor_t flavor; 213 int result = 0; 214 215 spin_lock(&rpc_authflavor_lock); 216 for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) { 217 const struct rpc_authops *ops = auth_flavors[flavor]; 218 rpc_authflavor_t pseudos[4]; 219 int i, len; 220 221 if (result >= size) { 222 result = -ENOMEM; 223 break; 224 } 225 226 if (ops == NULL) 227 continue; 228 if (ops->list_pseudoflavors == NULL) { 229 array[result++] = ops->au_flavor; 230 continue; 231 } 232 len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos)); 233 if (len < 0) { 234 result = len; 235 break; 236 } 237 for (i = 0; i < len; i++) { 238 if (result >= size) { 239 result = -ENOMEM; 240 break; 241 } 242 array[result++] = pseudos[i]; 243 } 244 } 245 spin_unlock(&rpc_authflavor_lock); 246 247 dprintk("RPC: %s returns %d\n", __func__, result); 248 return result; 249} 250EXPORT_SYMBOL_GPL(rpcauth_list_flavors); 251 252struct rpc_auth * 253rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt) 254{ 255 struct rpc_auth *auth; 256 const struct rpc_authops *ops; 257 u32 flavor = pseudoflavor_to_flavor(pseudoflavor); 258 259 auth = ERR_PTR(-EINVAL); 260 if (flavor >= RPC_AUTH_MAXFLAVOR) 261 goto out; 262 263 if ((ops = auth_flavors[flavor]) == NULL) 264 request_module("rpc-auth-%u", flavor); 265 spin_lock(&rpc_authflavor_lock); 266 ops = auth_flavors[flavor]; 267 if (ops == NULL || !try_module_get(ops->owner)) { 268 spin_unlock(&rpc_authflavor_lock); 269 goto out; 270 } 271 spin_unlock(&rpc_authflavor_lock); 272 auth = ops->create(clnt, pseudoflavor); 273 module_put(ops->owner); 274 if (IS_ERR(auth)) 275 return auth; 276 if (clnt->cl_auth) 277 rpcauth_release(clnt->cl_auth); 278 clnt->cl_auth = auth; 279 280out: 281 return auth; 282} 283EXPORT_SYMBOL_GPL(rpcauth_create); 284 285void 286rpcauth_release(struct rpc_auth *auth) 287{ 288 if (!atomic_dec_and_test(&auth->au_count)) 289 return; 290 auth->au_ops->destroy(auth); 291} 292 293static DEFINE_SPINLOCK(rpc_credcache_lock); 294 295static void 296rpcauth_unhash_cred_locked(struct rpc_cred *cred) 297{ 298 hlist_del_rcu(&cred->cr_hash); 299 smp_mb__before_clear_bit(); 300 clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); 301} 302 303static int 304rpcauth_unhash_cred(struct rpc_cred *cred) 305{ 306 spinlock_t *cache_lock; 307 int ret; 308 309 cache_lock = &cred->cr_auth->au_credcache->lock; 310 spin_lock(cache_lock); 311 ret = atomic_read(&cred->cr_count) == 0; 312 if (ret) 313 rpcauth_unhash_cred_locked(cred); 314 spin_unlock(cache_lock); 315 return ret; 316} 317 318/* 319 * Initialize RPC credential cache 320 */ 321int 322rpcauth_init_credcache(struct rpc_auth *auth) 323{ 324 struct rpc_cred_cache *new; 325 unsigned int hashsize; 326 327 new = kmalloc(sizeof(*new), GFP_KERNEL); 328 if (!new) 329 goto out_nocache; 330 new->hashbits = auth_hashbits; 331 hashsize = 1U << new->hashbits; 332 new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL); 333 if (!new->hashtable) 334 goto out_nohashtbl; 335 spin_lock_init(&new->lock); 336 auth->au_credcache = new; 337 return 0; 338out_nohashtbl: 339 kfree(new); 340out_nocache: 341 return -ENOMEM; 342} 343EXPORT_SYMBOL_GPL(rpcauth_init_credcache); 344 345/* 346 * Destroy a list of credentials 347 */ 348static inline 349void rpcauth_destroy_credlist(struct list_head *head) 350{ 351 struct rpc_cred *cred; 352 353 while (!list_empty(head)) { 354 cred = list_entry(head->next, struct rpc_cred, cr_lru); 355 list_del_init(&cred->cr_lru); 356 put_rpccred(cred); 357 } 358} 359 360/* 361 * Clear the RPC credential cache, and delete those credentials 362 * that are not referenced. 363 */ 364void 365rpcauth_clear_credcache(struct rpc_cred_cache *cache) 366{ 367 LIST_HEAD(free); 368 struct hlist_head *head; 369 struct rpc_cred *cred; 370 unsigned int hashsize = 1U << cache->hashbits; 371 int i; 372 373 spin_lock(&rpc_credcache_lock); 374 spin_lock(&cache->lock); 375 for (i = 0; i < hashsize; i++) { 376 head = &cache->hashtable[i]; 377 while (!hlist_empty(head)) { 378 cred = hlist_entry(head->first, struct rpc_cred, cr_hash); 379 get_rpccred(cred); 380 if (!list_empty(&cred->cr_lru)) { 381 list_del(&cred->cr_lru); 382 number_cred_unused--; 383 } 384 list_add_tail(&cred->cr_lru, &free); 385 rpcauth_unhash_cred_locked(cred); 386 } 387 } 388 spin_unlock(&cache->lock); 389 spin_unlock(&rpc_credcache_lock); 390 rpcauth_destroy_credlist(&free); 391} 392 393/* 394 * Destroy the RPC credential cache 395 */ 396void 397rpcauth_destroy_credcache(struct rpc_auth *auth) 398{ 399 struct rpc_cred_cache *cache = auth->au_credcache; 400 401 if (cache) { 402 auth->au_credcache = NULL; 403 rpcauth_clear_credcache(cache); 404 kfree(cache->hashtable); 405 kfree(cache); 406 } 407} 408EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache); 409 410 411#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ) 412 413/* 414 * Remove stale credentials. Avoid sleeping inside the loop. 415 */ 416static int 417rpcauth_prune_expired(struct list_head *free, int nr_to_scan) 418{ 419 spinlock_t *cache_lock; 420 struct rpc_cred *cred, *next; 421 unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM; 422 423 list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) { 424 425 if (nr_to_scan-- == 0) 426 break; 427 /* 428 * Enforce a 60 second garbage collection moratorium 429 * Note that the cred_unused list must be time-ordered. 430 */ 431 if (time_in_range(cred->cr_expire, expired, jiffies) && 432 test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) 433 return 0; 434 435 list_del_init(&cred->cr_lru); 436 number_cred_unused--; 437 if (atomic_read(&cred->cr_count) != 0) 438 continue; 439 440 cache_lock = &cred->cr_auth->au_credcache->lock; 441 spin_lock(cache_lock); 442 if (atomic_read(&cred->cr_count) == 0) { 443 get_rpccred(cred); 444 list_add_tail(&cred->cr_lru, free); 445 rpcauth_unhash_cred_locked(cred); 446 } 447 spin_unlock(cache_lock); 448 } 449 return (number_cred_unused / 100) * sysctl_vfs_cache_pressure; 450} 451 452/* 453 * Run memory cache shrinker. 454 */ 455static int 456rpcauth_cache_shrinker(struct shrinker *shrink, struct shrink_control *sc) 457{ 458 LIST_HEAD(free); 459 int res; 460 int nr_to_scan = sc->nr_to_scan; 461 gfp_t gfp_mask = sc->gfp_mask; 462 463 if ((gfp_mask & GFP_KERNEL) != GFP_KERNEL) 464 return (nr_to_scan == 0) ? 0 : -1; 465 if (list_empty(&cred_unused)) 466 return 0; 467 spin_lock(&rpc_credcache_lock); 468 res = rpcauth_prune_expired(&free, nr_to_scan); 469 spin_unlock(&rpc_credcache_lock); 470 rpcauth_destroy_credlist(&free); 471 return res; 472} 473 474/* 475 * Look up a process' credentials in the authentication cache 476 */ 477struct rpc_cred * 478rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, 479 int flags) 480{ 481 LIST_HEAD(free); 482 struct rpc_cred_cache *cache = auth->au_credcache; 483 struct rpc_cred *cred = NULL, 484 *entry, *new; 485 unsigned int nr; 486 487 nr = hash_long(from_kuid(&init_user_ns, acred->uid), cache->hashbits); 488 489 rcu_read_lock(); 490 hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) { 491 if (!entry->cr_ops->crmatch(acred, entry, flags)) 492 continue; 493 spin_lock(&cache->lock); 494 if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) { 495 spin_unlock(&cache->lock); 496 continue; 497 } 498 cred = get_rpccred(entry); 499 spin_unlock(&cache->lock); 500 break; 501 } 502 rcu_read_unlock(); 503 504 if (cred != NULL) 505 goto found; 506 507 new = auth->au_ops->crcreate(auth, acred, flags); 508 if (IS_ERR(new)) { 509 cred = new; 510 goto out; 511 } 512 513 spin_lock(&cache->lock); 514 hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) { 515 if (!entry->cr_ops->crmatch(acred, entry, flags)) 516 continue; 517 cred = get_rpccred(entry); 518 break; 519 } 520 if (cred == NULL) { 521 cred = new; 522 set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); 523 hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]); 524 } else 525 list_add_tail(&new->cr_lru, &free); 526 spin_unlock(&cache->lock); 527found: 528 if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) && 529 cred->cr_ops->cr_init != NULL && 530 !(flags & RPCAUTH_LOOKUP_NEW)) { 531 int res = cred->cr_ops->cr_init(auth, cred); 532 if (res < 0) { 533 put_rpccred(cred); 534 cred = ERR_PTR(res); 535 } 536 } 537 rpcauth_destroy_credlist(&free); 538out: 539 return cred; 540} 541EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache); 542 543struct rpc_cred * 544rpcauth_lookupcred(struct rpc_auth *auth, int flags) 545{ 546 struct auth_cred acred; 547 struct rpc_cred *ret; 548 const struct cred *cred = current_cred(); 549 550 dprintk("RPC: looking up %s cred\n", 551 auth->au_ops->au_name); 552 553 memset(&acred, 0, sizeof(acred)); 554 acred.uid = cred->fsuid; 555 acred.gid = cred->fsgid; 556 acred.group_info = get_group_info(((struct cred *)cred)->group_info); 557 558 ret = auth->au_ops->lookup_cred(auth, &acred, flags); 559 put_group_info(acred.group_info); 560 return ret; 561} 562 563void 564rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, 565 struct rpc_auth *auth, const struct rpc_credops *ops) 566{ 567 INIT_HLIST_NODE(&cred->cr_hash); 568 INIT_LIST_HEAD(&cred->cr_lru); 569 atomic_set(&cred->cr_count, 1); 570 cred->cr_auth = auth; 571 cred->cr_ops = ops; 572 cred->cr_expire = jiffies; 573#ifdef RPC_DEBUG 574 cred->cr_magic = RPCAUTH_CRED_MAGIC; 575#endif 576 cred->cr_uid = acred->uid; 577} 578EXPORT_SYMBOL_GPL(rpcauth_init_cred); 579 580struct rpc_cred * 581rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags) 582{ 583 dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid, 584 cred->cr_auth->au_ops->au_name, cred); 585 return get_rpccred(cred); 586} 587EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred); 588 589static struct rpc_cred * 590rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags) 591{ 592 struct rpc_auth *auth = task->tk_client->cl_auth; 593 struct auth_cred acred = { 594 .uid = GLOBAL_ROOT_UID, 595 .gid = GLOBAL_ROOT_GID, 596 }; 597 598 dprintk("RPC: %5u looking up %s cred\n", 599 task->tk_pid, task->tk_client->cl_auth->au_ops->au_name); 600 return auth->au_ops->lookup_cred(auth, &acred, lookupflags); 601} 602 603static struct rpc_cred * 604rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags) 605{ 606 struct rpc_auth *auth = task->tk_client->cl_auth; 607 608 dprintk("RPC: %5u looking up %s cred\n", 609 task->tk_pid, auth->au_ops->au_name); 610 return rpcauth_lookupcred(auth, lookupflags); 611} 612 613static int 614rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags) 615{ 616 struct rpc_rqst *req = task->tk_rqstp; 617 struct rpc_cred *new; 618 int lookupflags = 0; 619 620 if (flags & RPC_TASK_ASYNC) 621 lookupflags |= RPCAUTH_LOOKUP_NEW; 622 if (cred != NULL) 623 new = cred->cr_ops->crbind(task, cred, lookupflags); 624 else if (flags & RPC_TASK_ROOTCREDS) 625 new = rpcauth_bind_root_cred(task, lookupflags); 626 else 627 new = rpcauth_bind_new_cred(task, lookupflags); 628 if (IS_ERR(new)) 629 return PTR_ERR(new); 630 if (req->rq_cred != NULL) 631 put_rpccred(req->rq_cred); 632 req->rq_cred = new; 633 return 0; 634} 635 636void 637put_rpccred(struct rpc_cred *cred) 638{ 639 /* Fast path for unhashed credentials */ 640 if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) { 641 if (atomic_dec_and_test(&cred->cr_count)) 642 cred->cr_ops->crdestroy(cred); 643 return; 644 } 645 646 if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock)) 647 return; 648 if (!list_empty(&cred->cr_lru)) { 649 number_cred_unused--; 650 list_del_init(&cred->cr_lru); 651 } 652 if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) { 653 if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) { 654 cred->cr_expire = jiffies; 655 list_add_tail(&cred->cr_lru, &cred_unused); 656 number_cred_unused++; 657 goto out_nodestroy; 658 } 659 if (!rpcauth_unhash_cred(cred)) { 660 /* We were hashed and someone looked us up... */ 661 goto out_nodestroy; 662 } 663 } 664 spin_unlock(&rpc_credcache_lock); 665 cred->cr_ops->crdestroy(cred); 666 return; 667out_nodestroy: 668 spin_unlock(&rpc_credcache_lock); 669} 670EXPORT_SYMBOL_GPL(put_rpccred); 671 672__be32 * 673rpcauth_marshcred(struct rpc_task *task, __be32 *p) 674{ 675 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 676 677 dprintk("RPC: %5u marshaling %s cred %p\n", 678 task->tk_pid, cred->cr_auth->au_ops->au_name, cred); 679 680 return cred->cr_ops->crmarshal(task, p); 681} 682 683__be32 * 684rpcauth_checkverf(struct rpc_task *task, __be32 *p) 685{ 686 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 687 688 dprintk("RPC: %5u validating %s cred %p\n", 689 task->tk_pid, cred->cr_auth->au_ops->au_name, cred); 690 691 return cred->cr_ops->crvalidate(task, p); 692} 693 694static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, 695 __be32 *data, void *obj) 696{ 697 struct xdr_stream xdr; 698 699 xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data); 700 encode(rqstp, &xdr, obj); 701} 702 703int 704rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp, 705 __be32 *data, void *obj) 706{ 707 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 708 709 dprintk("RPC: %5u using %s cred %p to wrap rpc data\n", 710 task->tk_pid, cred->cr_ops->cr_name, cred); 711 if (cred->cr_ops->crwrap_req) 712 return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj); 713 /* By default, we encode the arguments normally. */ 714 rpcauth_wrap_req_encode(encode, rqstp, data, obj); 715 return 0; 716} 717 718static int 719rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, 720 __be32 *data, void *obj) 721{ 722 struct xdr_stream xdr; 723 724 xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data); 725 return decode(rqstp, &xdr, obj); 726} 727 728int 729rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp, 730 __be32 *data, void *obj) 731{ 732 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 733 734 dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n", 735 task->tk_pid, cred->cr_ops->cr_name, cred); 736 if (cred->cr_ops->crunwrap_resp) 737 return cred->cr_ops->crunwrap_resp(task, decode, rqstp, 738 data, obj); 739 /* By default, we decode the arguments normally. */ 740 return rpcauth_unwrap_req_decode(decode, rqstp, data, obj); 741} 742 743int 744rpcauth_refreshcred(struct rpc_task *task) 745{ 746 struct rpc_cred *cred; 747 int err; 748 749 cred = task->tk_rqstp->rq_cred; 750 if (cred == NULL) { 751 err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags); 752 if (err < 0) 753 goto out; 754 cred = task->tk_rqstp->rq_cred; 755 } 756 dprintk("RPC: %5u refreshing %s cred %p\n", 757 task->tk_pid, cred->cr_auth->au_ops->au_name, cred); 758 759 err = cred->cr_ops->crrefresh(task); 760out: 761 if (err < 0) 762 task->tk_status = err; 763 return err; 764} 765 766void 767rpcauth_invalcred(struct rpc_task *task) 768{ 769 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 770 771 dprintk("RPC: %5u invalidating %s cred %p\n", 772 task->tk_pid, cred->cr_auth->au_ops->au_name, cred); 773 if (cred) 774 clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); 775} 776 777int 778rpcauth_uptodatecred(struct rpc_task *task) 779{ 780 struct rpc_cred *cred = task->tk_rqstp->rq_cred; 781 782 return cred == NULL || 783 test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0; 784} 785 786static struct shrinker rpc_cred_shrinker = { 787 .shrink = rpcauth_cache_shrinker, 788 .seeks = DEFAULT_SEEKS, 789}; 790 791int __init rpcauth_init_module(void) 792{ 793 int err; 794 795 err = rpc_init_authunix(); 796 if (err < 0) 797 goto out1; 798 err = rpc_init_generic_auth(); 799 if (err < 0) 800 goto out2; 801 register_shrinker(&rpc_cred_shrinker); 802 return 0; 803out2: 804 rpc_destroy_authunix(); 805out1: 806 return err; 807} 808 809void rpcauth_remove_module(void) 810{ 811 rpc_destroy_authunix(); 812 rpc_destroy_generic_auth(); 813 unregister_shrinker(&rpc_cred_shrinker); 814} 815