1// Copyright (c) 2008 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
// Written in NSPR style so that it is also suitable for adding to the NSS demo suite.
5
/* memio is a simple NSPR I/O layer that lets you decouple NSS from
 * the real network.  It's rather like OpenSSL's memory BIO,
 * and is useful when your app absolutely, positively doesn't
 * want to let NSS do its own networking.
 */
11
12#include <stdlib.h>
13#include <string.h>
14
15#include <prerror.h>
16#include <prinit.h>
17#include <prlog.h>
18
19#include "nss_memio.h"
20
21/*--------------- private memio types -----------------------*/
22
23/*----------------------------------------------------------------------
24 Simple private circular buffer class.  Size cannot be changed once allocated.
25----------------------------------------------------------------------*/
26
27struct memio_buffer {
28    int head;     /* where to take next byte out of buf */
29    int tail;     /* where to put next byte into buf */
30    int bufsize;  /* number of bytes allocated to buf */
31    /* TODO(port): error handling is pessimistic right now.
32     * Once an error is set, the socket is considered broken
33     * (PR_WOULD_BLOCK_ERROR not included).
34     */
35    PRErrorCode last_err;
36    char *buf;
37};
38
39
40/* The 'secret' field of a PRFileDesc created by memio_CreateIOLayer points
41 * to one of these.
42 * In the public header, we use struct memio_Private as a typesafe alias
43 * for this.  This causes a few ugly typecasts in the private file, but
44 * seems safer.
45 */
46struct PRFilePrivate {
47    /* read requests are satisfied from this buffer */
48    struct memio_buffer readbuf;
49
50    /* write requests are satisfied from this buffer */
51    struct memio_buffer writebuf;
52
53    /* SSL needs to know socket peer's name */
54    PRNetAddr peername;
55
56    /* if set, empty I/O returns EOF instead of EWOULDBLOCK */
57    int eof;
58};
59
60/*--------------- private memio_buffer functions ---------------------*/
61
/* Forward declarations of the ring-buffer helpers.  Each one's contract is
 * documented at its definition below.
 */
static void memio_buffer_new(struct memio_buffer *mb, int size);
static void memio_buffer_destroy(struct memio_buffer *mb);
static int memio_buffer_used_contiguous(const struct memio_buffer *mb);
static int memio_buffer_wrapped_bytes(const struct memio_buffer *mb);
static int memio_buffer_unused_contiguous(const struct memio_buffer *mb);
static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n);
static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n);
84
85/* Allocate a memio_buffer of given size. */
86static void memio_buffer_new(struct memio_buffer *mb, int size)
87{
88    mb->head = 0;
89    mb->tail = 0;
90    mb->bufsize = size;
91    mb->buf = malloc(size);
92}
93
94/* Deallocate a memio_buffer allocated by memio_buffer_new. */
95static void memio_buffer_destroy(struct memio_buffer *mb)
96{
97    free(mb->buf);
98    mb->buf = NULL;
99    mb->head = 0;
100    mb->tail = 0;
101}
102
103/* How many bytes can be read out of the buffer without wrapping */
104static int memio_buffer_used_contiguous(const struct memio_buffer *mb)
105{
106    return (((mb->tail >= mb->head) ? mb->tail : mb->bufsize) - mb->head);
107}
108
109/* How many bytes exist after the wrap? */
110static int memio_buffer_wrapped_bytes(const struct memio_buffer *mb)
111{
112    return (mb->tail >= mb->head) ? 0 : mb->tail;
113}
114
115/* How many bytes can be written into the buffer without wrapping */
116static int memio_buffer_unused_contiguous(const struct memio_buffer *mb)
117{
118    if (mb->head > mb->tail) return mb->head - mb->tail - 1;
119    return mb->bufsize - mb->tail - (mb->head == 0);
120}
121
122/* Write n bytes into the buffer.  Returns number of bytes written. */
123static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n)
124{
125    int len;
126    int transferred = 0;
127
128    /* Handle part before wrap */
129    len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
130    if (len > 0) {
131        /* Buffer not full */
132        memcpy(&mb->buf[mb->tail], buf, len);
133        mb->tail += len;
134        if (mb->tail == mb->bufsize)
135            mb->tail = 0;
136        n -= len;
137        buf += len;
138        transferred += len;
139
140        /* Handle part after wrap */
141        len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
142        if (len > 0) {
143            /* Output buffer still not full, input buffer still not empty */
144            memcpy(&mb->buf[mb->tail], buf, len);
145            mb->tail += len;
146            if (mb->tail == mb->bufsize)
147                mb->tail = 0;
148                transferred += len;
149        }
150    }
151
152    return transferred;
153}
154
155
156/* Read n bytes from the buffer.  Returns number of bytes read. */
157static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n)
158{
159    int len;
160    int transferred = 0;
161
162    /* Handle part before wrap */
163    len = PR_MIN(n, memio_buffer_used_contiguous(mb));
164    if (len) {
165        memcpy(buf, &mb->buf[mb->head], len);
166        mb->head += len;
167        if (mb->head == mb->bufsize)
168            mb->head = 0;
169        n -= len;
170        buf += len;
171        transferred += len;
172
173        /* Handle part after wrap */
174        len = PR_MIN(n, memio_buffer_used_contiguous(mb));
175        if (len) {
176        memcpy(buf, &mb->buf[mb->head], len);
177        mb->head += len;
178            if (mb->head == mb->bufsize)
179                mb->head = 0;
180                transferred += len;
181        }
182    }
183
184    return transferred;
185}
186
187/*--------------- private memio functions -----------------------*/
188
189static PRStatus PR_CALLBACK memio_Close(PRFileDesc *fd)
190{
191    struct PRFilePrivate *secret = fd->secret;
192    memio_buffer_destroy(&secret->readbuf);
193    memio_buffer_destroy(&secret->writebuf);
194    free(secret);
195    fd->dtor(fd);
196    return PR_SUCCESS;
197}
198
199static PRStatus PR_CALLBACK memio_Shutdown(PRFileDesc *fd, PRIntn how)
200{
201    /* TODO: pass shutdown status to app somehow */
202    return PR_SUCCESS;
203}
204
205/* If there was a network error in the past taking bytes
206 * out of the buffer, return it to the next call that
207 * tries to read from an empty buffer.
208 */
209static int PR_CALLBACK memio_Recv(PRFileDesc *fd, void *buf, PRInt32 len,
210                                  PRIntn flags, PRIntervalTime timeout)
211{
212    struct PRFilePrivate *secret;
213    struct memio_buffer *mb;
214    int rv;
215
216    if (flags) {
217        PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
218        return -1;
219    }
220
221    secret = fd->secret;
222    mb = &secret->readbuf;
223    PR_ASSERT(mb->bufsize);
224    rv = memio_buffer_get(mb, buf, len);
225    if (rv == 0 && !secret->eof) {
226        if (mb->last_err)
227            PR_SetError(mb->last_err, 0);
228        else
229            PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
230        return -1;
231    }
232
233    return rv;
234}
235
236static int PR_CALLBACK memio_Read(PRFileDesc *fd, void *buf, PRInt32 len)
237{
238    /* pull bytes from buffer */
239    return memio_Recv(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
240}
241
242static int PR_CALLBACK memio_Send(PRFileDesc *fd, const void *buf, PRInt32 len,
243                                  PRIntn flags, PRIntervalTime timeout)
244{
245    struct PRFilePrivate *secret;
246    struct memio_buffer *mb;
247    int rv;
248
249    secret = fd->secret;
250    mb = &secret->writebuf;
251    PR_ASSERT(mb->bufsize);
252
253    if (mb->last_err) {
254        PR_SetError(mb->last_err, 0);
255        return -1;
256    }
257    rv = memio_buffer_put(mb, buf, len);
258    if (rv == 0) {
259        PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
260        return -1;
261    }
262    return rv;
263}
264
265static int PR_CALLBACK memio_Write(PRFileDesc *fd, const void *buf, PRInt32 len)
266{
267    /* append bytes to buffer */
268    return memio_Send(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
269}
270
271static PRStatus PR_CALLBACK memio_GetPeerName(PRFileDesc *fd, PRNetAddr *addr)
272{
273    /* TODO: fail if memio_SetPeerName has not been called */
274    struct PRFilePrivate *secret = fd->secret;
275    *addr = secret->peername;
276    return PR_SUCCESS;
277}
278
279static PRStatus memio_GetSocketOption(PRFileDesc *fd, PRSocketOptionData *data)
280{
281    /*
282     * Even in the original version for real tcp sockets,
283     * PR_SockOpt_Nonblocking is a special case that does not
284     * translate to a getsockopt() call
285     */
286    if (PR_SockOpt_Nonblocking == data->option) {
287        data->value.non_blocking = PR_TRUE;
288        return PR_SUCCESS;
289    }
290    PR_SetError(PR_OPERATION_NOT_SUPPORTED_ERROR, 0);
291    return PR_FAILURE;
292}
293
294/*--------------- private memio data -----------------------*/
295
296/*
297 * Implement just the bare minimum number of methods needed to make ssl happy.
298 *
299 * Oddly, PR_Recv calls ssl_Recv calls ssl_SocketIsBlocking calls
300 * PR_GetSocketOption, so we have to provide an implementation of
301 * PR_GetSocketOption that just says "I'm nonblocking".
302 */
303
304static struct PRIOMethods  memio_layer_methods = {
305    PR_DESC_LAYERED,
306    memio_Close,
307    memio_Read,
308    memio_Write,
309    NULL,
310    NULL,
311    NULL,
312    NULL,
313    NULL,
314    NULL,
315    NULL,
316    NULL,
317    NULL,
318    NULL,
319    NULL,
320    NULL,
321    memio_Shutdown,
322    memio_Recv,
323    memio_Send,
324    NULL,
325    NULL,
326    NULL,
327    NULL,
328    NULL,
329    NULL,
330    memio_GetPeerName,
331    NULL,
332    NULL,
333    memio_GetSocketOption,
334    NULL,
335    NULL,
336    NULL,
337    NULL,
338    NULL,
339    NULL,
340    NULL,
341};
342
343static PRDescIdentity memio_identity = PR_INVALID_IO_LAYER;
344
345static PRStatus memio_InitializeLayerName(void)
346{
347    memio_identity = PR_GetUniqueIdentity("memio");
348    return PR_SUCCESS;
349}
350
351/*--------------- public memio functions -----------------------*/
352
353PRFileDesc *memio_CreateIOLayer(int bufsize)
354{
355    PRFileDesc *fd;
356    struct PRFilePrivate *secret;
357    static PRCallOnceType once;
358
359    PR_CallOnce(&once, memio_InitializeLayerName);
360
361    fd = PR_CreateIOLayerStub(memio_identity, &memio_layer_methods);
362    secret = malloc(sizeof(struct PRFilePrivate));
363    memset(secret, 0, sizeof(*secret));
364
365    memio_buffer_new(&secret->readbuf, bufsize);
366    memio_buffer_new(&secret->writebuf, bufsize);
367    fd->secret = secret;
368    return fd;
369}
370
371void memio_SetPeerName(PRFileDesc *fd, const PRNetAddr *peername)
372{
373    PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
374    struct PRFilePrivate *secret = memiofd->secret;
375    secret->peername = *peername;
376}
377
378memio_Private *memio_GetSecret(PRFileDesc *fd)
379{
380    PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
381    struct PRFilePrivate *secret =  memiofd->secret;
382    return (memio_Private *)secret;
383}
384
385int memio_GetReadParams(memio_Private *secret, char **buf)
386{
387    struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
388    PR_ASSERT(mb->bufsize);
389
390    *buf = &mb->buf[mb->tail];
391    return memio_buffer_unused_contiguous(mb);
392}
393
394void memio_PutReadResult(memio_Private *secret, int bytes_read)
395{
396    struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
397    PR_ASSERT(mb->bufsize);
398
399    if (bytes_read > 0) {
400        mb->tail += bytes_read;
401        if (mb->tail == mb->bufsize)
402            mb->tail = 0;
403    } else if (bytes_read == 0) {
404        /* Record EOF condition and report to caller when buffer runs dry */
405        ((PRFilePrivate *)secret)->eof = PR_TRUE;
406    } else /* if (bytes_read < 0) */ {
407        mb->last_err = bytes_read;
408    }
409}
410
411void memio_GetWriteParams(memio_Private *secret,
412                          const char **buf1, unsigned int *len1,
413                          const char **buf2, unsigned int *len2)
414{
415    struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
416    PR_ASSERT(mb->bufsize);
417
418    *buf1 = &mb->buf[mb->head];
419    *len1 = memio_buffer_used_contiguous(mb);
420    *buf2 = mb->buf;
421    *len2 = memio_buffer_wrapped_bytes(mb);
422}
423
424void memio_PutWriteResult(memio_Private *secret, int bytes_written)
425{
426    struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
427    PR_ASSERT(mb->bufsize);
428
429    if (bytes_written > 0) {
430        mb->head += bytes_written;
431        if (mb->head >= mb->bufsize)
432            mb->head -= mb->bufsize;
433    } else if (bytes_written < 0) {
434        mb->last_err = bytes_written;
435    }
436}
437
438/*--------------- private memio_buffer self-test -----------------*/
439
440/* Even a trivial unit test is very helpful when doing circular buffers. */
441/*#define TRIVIAL_SELF_TEST*/
#ifdef TRIVIAL_SELF_TEST
#include <stdio.h>

#define TEST_BUFLEN 7

/* Compare two int-valued expressions and abort on mismatch.  Each argument
 * is evaluated exactly once, so an expression with side effects (such as
 * memio_buffer_put/get) is not re-run when the failure is printed.
 */
#define CHECKEQ(a, b) do { \
    int checkeq_a_ = (a); \
    int checkeq_b_ = (b); \
    if (checkeq_a_ != checkeq_b_) { \
        printf("%d != %d, Test failed line %d\n", \
               checkeq_a_, checkeq_b_, __LINE__); \
        exit(1); \
    } \
} while (0)

/* Exercise the ring buffer through fill, drain and wrap-around states with
 * a deliberately tiny capacity (TEST_BUFLEN - 1 usable bytes).
 */
int main(void)
{
    struct memio_buffer mb;
    char buf[100];

    memset(&mb, 0, sizeof(mb));
    memio_buffer_new(&mb, TEST_BUFLEN);
    if (mb.buf == NULL) {
        printf("allocation failed\n");
        return 1;
    }

    CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 0);

    CHECKEQ(memio_buffer_put(&mb, "howdy", 5), 5);

    CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 5);
    CHECKEQ(memio_buffer_wrapped_bytes(&mb), 0);

    CHECKEQ(memio_buffer_put(&mb, "!", 1), 1);

    CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 6);
    CHECKEQ(memio_buffer_wrapped_bytes(&mb), 0);

    CHECKEQ(memio_buffer_get(&mb, buf, 6), 6);
    CHECKEQ(memcmp(buf, "howdy!", 6), 0);

    CHECKEQ(memio_buffer_unused_contiguous(&mb), 1);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 0);

    /* This put wraps: 1 byte at the end of buf, 4 at the start. */
    CHECKEQ(memio_buffer_put(&mb, "01234", 5), 5);

    CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    CHECKEQ(memio_buffer_wrapped_bytes(&mb), 4);
    CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);

    CHECKEQ(memio_buffer_put(&mb, "5", 1), 1);

    CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 1);

    /* Drain across the wrap point: 1 byte before it, 5 after. */
    CHECKEQ(memio_buffer_get(&mb, buf, (int)sizeof(buf)), 6);
    CHECKEQ(memcmp(buf, "012345", 6), 0);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    CHECKEQ(memio_buffer_wrapped_bytes(&mb), 0);
    CHECKEQ(memio_buffer_unused_contiguous(&mb), 2);

    /* Reading from an empty buffer transfers nothing. */
    CHECKEQ(memio_buffer_get(&mb, buf, 1), 0);

    /* A second wrapping put: 2 bytes at the end of buf, 1 at the start. */
    CHECKEQ(memio_buffer_put(&mb, "abc", 3), 3);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 2);
    CHECKEQ(memio_buffer_wrapped_bytes(&mb), 1);

    memio_buffer_destroy(&mb);

    printf("Test passed\n");
    return 0;
}

#endif
501