1/* pb_encode.c -- encode a protobuf using minimal resources
2 *
3 * 2011 Petteri Aimonen <jpa@kapsi.fi>
4 */
5
6#include "pb.h"
7#include "pb_encode.h"
8
9/* Use the GCC warn_unused_result attribute to check that all return values
10 * are propagated correctly. On other compilers and gcc before 3.4.0 just
11 * ignore the annotation.
12 */
#if !defined(__GNUC__) || ( __GNUC__ < 3) || (__GNUC__ == 3 && __GNUC_MINOR__ < 4)
    /* Not GCC, or GCC older than 3.4.0: the annotation expands to nothing. */
    #define checkreturn
#else
    /* GCC 3.4.0 or newer: warn whenever a caller discards the return value. */
    #define checkreturn __attribute__((warn_unused_result))
#endif
18
19/**************************************
20 * Declarations internal to this file *
21 **************************************/
/* Common signature of the per-type field encoders stored in PB_ENCODERS:
 * encode the value found at 'src', described by 'field', into 'stream'.
 * A false return value signals failure. */
typedef bool (*pb_encoder_t)(pb_ostream_t *stream, const pb_field_t *field, const void *src) checkreturn;

/* Stream callback used for memory buffer streams. */
static bool checkreturn buf_write(pb_ostream_t *stream, const uint8_t *buf, size_t count);

/* Field-level dispatch helpers. */
static bool checkreturn encode_array(pb_ostream_t *stream, const pb_field_t *field, const void *pData, size_t count, pb_encoder_t func);
static bool checkreturn encode_field(pb_ostream_t *stream, const pb_field_t *field, const void *pData);
static bool checkreturn default_extension_encoder(pb_ostream_t *stream, const pb_extension_t *extension);
static bool checkreturn encode_extension_field(pb_ostream_t *stream, const pb_field_t *field, const void *pData);

/* Per-type encoders; these are the entries of the PB_ENCODERS table. */
static bool checkreturn pb_enc_varint(pb_ostream_t *stream, const pb_field_t *field, const void *src);
static bool checkreturn pb_enc_uvarint(pb_ostream_t *stream, const pb_field_t *field, const void *src);
static bool checkreturn pb_enc_svarint(pb_ostream_t *stream, const pb_field_t *field, const void *src);
static bool checkreturn pb_enc_fixed32(pb_ostream_t *stream, const pb_field_t *field, const void *src);
static bool checkreturn pb_enc_fixed64(pb_ostream_t *stream, const pb_field_t *field, const void *src);
static bool checkreturn pb_enc_bytes(pb_ostream_t *stream, const pb_field_t *field, const void *src);
static bool checkreturn pb_enc_string(pb_ostream_t *stream, const pb_field_t *field, const void *src);
static bool checkreturn pb_enc_submessage(pb_ostream_t *stream, const pb_field_t *field, const void *src);
37
/* --- Function pointers to field encoders ---
 * Order in the array must match pb_action_t LTYPE numbering.
 *
 * The table is indexed with PB_LTYPE(field->type) in encode_basic_field().
 * The last slot is NULL because extension fields are not encoded through
 * this table; pb_encode() routes them to encode_extension_field() instead.
 */
static const pb_encoder_t PB_ENCODERS[PB_LTYPES_COUNT] = {
    &pb_enc_varint,
    &pb_enc_uvarint,
    &pb_enc_svarint,
    &pb_enc_fixed32,
    &pb_enc_fixed64,

    &pb_enc_bytes,
    &pb_enc_string,
    &pb_enc_submessage,
    NULL /* extensions */
};
53
54/*******************************
55 * pb_ostream_t implementation *
56 *******************************/
57
58static bool checkreturn buf_write(pb_ostream_t *stream, const uint8_t *buf, size_t count)
59{
60    uint8_t *dest = (uint8_t*)stream->state;
61    stream->state = dest + count;
62
63    while (count--)
64        *dest++ = *buf++;
65
66    return true;
67}
68
69pb_ostream_t pb_ostream_from_buffer(uint8_t *buf, size_t bufsize)
70{
71    pb_ostream_t stream;
72#ifdef PB_BUFFER_ONLY
73    stream.callback = (void*)1; /* Just a marker value */
74#else
75    stream.callback = &buf_write;
76#endif
77    stream.state = buf;
78    stream.max_size = bufsize;
79    stream.bytes_written = 0;
80#ifndef PB_NO_ERRMSG
81    stream.errmsg = NULL;
82#endif
83    return stream;
84}
85
86bool checkreturn pb_write(pb_ostream_t *stream, const uint8_t *buf, size_t count)
87{
88    if (stream->callback != NULL)
89    {
90        if (stream->bytes_written + count > stream->max_size)
91            PB_RETURN_ERROR(stream, "stream full");
92
93#ifdef PB_BUFFER_ONLY
94        if (!buf_write(stream, buf, count))
95            PB_RETURN_ERROR(stream, "io error");
96#else
97        if (!stream->callback(stream, buf, count))
98            PB_RETURN_ERROR(stream, "io error");
99#endif
100    }
101
102    stream->bytes_written += count;
103    return true;
104}
105
106/*************************
107 * Encode a single field *
108 *************************/
109
110/* Encode a static array. Handles the size calculations and possible packing. */
111static bool checkreturn encode_array(pb_ostream_t *stream, const pb_field_t *field,
112                         const void *pData, size_t count, pb_encoder_t func)
113{
114    size_t i;
115    const void *p;
116    size_t size;
117
118    if (count == 0)
119        return true;
120
121    if (PB_ATYPE(field->type) != PB_ATYPE_POINTER && count > field->array_size)
122        PB_RETURN_ERROR(stream, "array max size exceeded");
123
124    /* We always pack arrays if the datatype allows it. */
125    if (PB_LTYPE(field->type) <= PB_LTYPE_LAST_PACKABLE)
126    {
127        if (!pb_encode_tag(stream, PB_WT_STRING, field->tag))
128            return false;
129
130        /* Determine the total size of packed array. */
131        if (PB_LTYPE(field->type) == PB_LTYPE_FIXED32)
132        {
133            size = 4 * count;
134        }
135        else if (PB_LTYPE(field->type) == PB_LTYPE_FIXED64)
136        {
137            size = 8 * count;
138        }
139        else
140        {
141            pb_ostream_t sizestream = PB_OSTREAM_SIZING;
142            p = pData;
143            for (i = 0; i < count; i++)
144            {
145                if (!func(&sizestream, field, p))
146                    return false;
147                p = (const char*)p + field->data_size;
148            }
149            size = sizestream.bytes_written;
150        }
151
152        if (!pb_encode_varint(stream, (uint64_t)size))
153            return false;
154
155        if (stream->callback == NULL)
156            return pb_write(stream, NULL, size); /* Just sizing.. */
157
158        /* Write the data */
159        p = pData;
160        for (i = 0; i < count; i++)
161        {
162            if (!func(stream, field, p))
163                return false;
164            p = (const char*)p + field->data_size;
165        }
166    }
167    else
168    {
169        p = pData;
170        for (i = 0; i < count; i++)
171        {
172            if (!pb_encode_tag_for_field(stream, field))
173                return false;
174
175            /* Normally the data is stored directly in the array entries, but
176             * for pointer-type string and bytes fields, the array entries are
177             * actually pointers themselves also. So we have to dereference once
178             * more to get to the actual data. */
179            if (PB_ATYPE(field->type) == PB_ATYPE_POINTER &&
180                (PB_LTYPE(field->type) == PB_LTYPE_STRING ||
181                 PB_LTYPE(field->type) == PB_LTYPE_BYTES))
182            {
183                if (!func(stream, field, *(const void* const*)p))
184                    return false;
185            }
186            else
187            {
188                if (!func(stream, field, p))
189                    return false;
190            }
191            p = (const char*)p + field->data_size;
192        }
193    }
194
195    return true;
196}
197
198/* Encode a field with static or pointer allocation, i.e. one whose data
199 * is available to the encoder directly. */
200static bool checkreturn encode_basic_field(pb_ostream_t *stream,
201    const pb_field_t *field, const void *pData)
202{
203    pb_encoder_t func;
204    const void *pSize;
205    bool implicit_has = true;
206
207    func = PB_ENCODERS[PB_LTYPE(field->type)];
208
209    if (field->size_offset)
210        pSize = (const char*)pData + field->size_offset;
211    else
212        pSize = &implicit_has;
213
214    if (PB_ATYPE(field->type) == PB_ATYPE_POINTER)
215    {
216        /* pData is a pointer to the field, which contains pointer to
217         * the data. If the 2nd pointer is NULL, it is interpreted as if
218         * the has_field was false.
219         */
220
221        pData = *(const void* const*)pData;
222        implicit_has = (pData != NULL);
223    }
224
225    switch (PB_HTYPE(field->type))
226    {
227        case PB_HTYPE_REQUIRED:
228            if (!pData)
229                PB_RETURN_ERROR(stream, "missing required field");
230            if (!pb_encode_tag_for_field(stream, field))
231                return false;
232            if (!func(stream, field, pData))
233                return false;
234            break;
235
236        case PB_HTYPE_OPTIONAL:
237            if (*(const bool*)pSize)
238            {
239                if (!pb_encode_tag_for_field(stream, field))
240                    return false;
241
242                if (!func(stream, field, pData))
243                    return false;
244            }
245            break;
246
247        case PB_HTYPE_REPEATED:
248            if (!encode_array(stream, field, pData, *(const size_t*)pSize, func))
249                return false;
250            break;
251
252        default:
253            PB_RETURN_ERROR(stream, "invalid field type");
254    }
255
256    return true;
257}
258
259/* Encode a field with callback semantics. This means that a user function is
260 * called to provide and encode the actual data. */
261static bool checkreturn encode_callback_field(pb_ostream_t *stream,
262    const pb_field_t *field, const void *pData)
263{
264    const pb_callback_t *callback = (const pb_callback_t*)pData;
265
266#ifdef PB_OLD_CALLBACK_STYLE
267    const void *arg = callback->arg;
268#else
269    void * const *arg = &(callback->arg);
270#endif
271
272    if (callback->funcs.encode != NULL)
273    {
274        if (!callback->funcs.encode(stream, field, arg))
275            PB_RETURN_ERROR(stream, "callback error");
276    }
277    return true;
278}
279
280/* Encode a single field of any callback or static type. */
281static bool checkreturn encode_field(pb_ostream_t *stream,
282    const pb_field_t *field, const void *pData)
283{
284    switch (PB_ATYPE(field->type))
285    {
286        case PB_ATYPE_STATIC:
287        case PB_ATYPE_POINTER:
288            return encode_basic_field(stream, field, pData);
289
290        case PB_ATYPE_CALLBACK:
291            return encode_callback_field(stream, field, pData);
292
293        default:
294            PB_RETURN_ERROR(stream, "invalid field type");
295    }
296}
297
298/* Default handler for extension fields. Expects to have a pb_field_t
299 * pointer in the extension->type->arg field. */
300static bool checkreturn default_extension_encoder(pb_ostream_t *stream,
301    const pb_extension_t *extension)
302{
303    const pb_field_t *field = (const pb_field_t*)extension->type->arg;
304    return encode_field(stream, field, extension->dest);
305}
306
307/* Walk through all the registered extensions and give them a chance
308 * to encode themselves. */
309static bool checkreturn encode_extension_field(pb_ostream_t *stream,
310    const pb_field_t *field, const void *pData)
311{
312    const pb_extension_t *extension = *(const pb_extension_t* const *)pData;
313    UNUSED(field);
314
315    while (extension)
316    {
317        bool status;
318        if (extension->type->encode)
319            status = extension->type->encode(stream, extension);
320        else
321            status = default_extension_encoder(stream, extension);
322
323        if (!status)
324            return false;
325
326        extension = extension->next;
327    }
328
329    return true;
330}
331
332/*********************
333 * Encode all fields *
334 *********************/
335
336bool checkreturn pb_encode(pb_ostream_t *stream, const pb_field_t fields[], const void *src_struct)
337{
338    const pb_field_t *field = fields;
339    const void *pData = src_struct;
340    size_t prev_size = 0;
341
342    while (field->tag != 0)
343    {
344        pData = (const char*)pData + prev_size + field->data_offset;
345        if (PB_ATYPE(field->type) == PB_ATYPE_POINTER)
346            prev_size = sizeof(const void*);
347        else
348            prev_size = field->data_size;
349
350        /* Special case for static arrays */
351        if (PB_ATYPE(field->type) == PB_ATYPE_STATIC &&
352            PB_HTYPE(field->type) == PB_HTYPE_REPEATED)
353        {
354            prev_size *= field->array_size;
355        }
356
357        if (PB_LTYPE(field->type) == PB_LTYPE_EXTENSION)
358        {
359            /* Special case for the extension field placeholder */
360            if (!encode_extension_field(stream, field, pData))
361                return false;
362        }
363        else
364        {
365            /* Regular field */
366            if (!encode_field(stream, field, pData))
367                return false;
368        }
369
370        field++;
371    }
372
373    return true;
374}
375
376bool pb_encode_delimited(pb_ostream_t *stream, const pb_field_t fields[], const void *src_struct)
377{
378    return pb_encode_submessage(stream, fields, src_struct);
379}
380
381bool pb_get_encoded_size(size_t *size, const pb_field_t fields[], const void *src_struct)
382{
383    pb_ostream_t stream = PB_OSTREAM_SIZING;
384
385    if (!pb_encode(&stream, fields, src_struct))
386        return false;
387
388    *size = stream.bytes_written;
389    return true;
390}
391
392/********************
393 * Helper functions *
394 ********************/
395bool checkreturn pb_encode_varint(pb_ostream_t *stream, uint64_t value)
396{
397    uint8_t buffer[10];
398    size_t i = 0;
399
400    if (value == 0)
401        return pb_write(stream, (uint8_t*)&value, 1);
402
403    while (value)
404    {
405        buffer[i] = (uint8_t)((value & 0x7F) | 0x80);
406        value >>= 7;
407        i++;
408    }
409    buffer[i-1] &= 0x7F; /* Unset top bit on last byte */
410
411    return pb_write(stream, buffer, i);
412}
413
414bool checkreturn pb_encode_svarint(pb_ostream_t *stream, int64_t value)
415{
416    uint64_t zigzagged;
417    if (value < 0)
418        zigzagged = ~((uint64_t)value << 1);
419    else
420        zigzagged = (uint64_t)value << 1;
421
422    return pb_encode_varint(stream, zigzagged);
423}
424
425bool checkreturn pb_encode_fixed32(pb_ostream_t *stream, const void *value)
426{
427    #ifdef __BIG_ENDIAN__
428    const uint8_t *bytes = value;
429    uint8_t lebytes[4];
430    lebytes[0] = bytes[3];
431    lebytes[1] = bytes[2];
432    lebytes[2] = bytes[1];
433    lebytes[3] = bytes[0];
434    return pb_write(stream, lebytes, 4);
435    #else
436    return pb_write(stream, (const uint8_t*)value, 4);
437    #endif
438}
439
440bool checkreturn pb_encode_fixed64(pb_ostream_t *stream, const void *value)
441{
442    #ifdef __BIG_ENDIAN__
443    const uint8_t *bytes = value;
444    uint8_t lebytes[8];
445    lebytes[0] = bytes[7];
446    lebytes[1] = bytes[6];
447    lebytes[2] = bytes[5];
448    lebytes[3] = bytes[4];
449    lebytes[4] = bytes[3];
450    lebytes[5] = bytes[2];
451    lebytes[6] = bytes[1];
452    lebytes[7] = bytes[0];
453    return pb_write(stream, lebytes, 8);
454    #else
455    return pb_write(stream, (const uint8_t*)value, 8);
456    #endif
457}
458
459bool checkreturn pb_encode_tag(pb_ostream_t *stream, pb_wire_type_t wiretype, uint32_t field_number)
460{
461    uint64_t tag = ((uint64_t)field_number << 3) | wiretype;
462    return pb_encode_varint(stream, tag);
463}
464
465bool checkreturn pb_encode_tag_for_field(pb_ostream_t *stream, const pb_field_t *field)
466{
467    pb_wire_type_t wiretype;
468    switch (PB_LTYPE(field->type))
469    {
470        case PB_LTYPE_VARINT:
471        case PB_LTYPE_UVARINT:
472        case PB_LTYPE_SVARINT:
473            wiretype = PB_WT_VARINT;
474            break;
475
476        case PB_LTYPE_FIXED32:
477            wiretype = PB_WT_32BIT;
478            break;
479
480        case PB_LTYPE_FIXED64:
481            wiretype = PB_WT_64BIT;
482            break;
483
484        case PB_LTYPE_BYTES:
485        case PB_LTYPE_STRING:
486        case PB_LTYPE_SUBMESSAGE:
487            wiretype = PB_WT_STRING;
488            break;
489
490        default:
491            PB_RETURN_ERROR(stream, "invalid field type");
492    }
493
494    return pb_encode_tag(stream, wiretype, field->tag);
495}
496
497bool checkreturn pb_encode_string(pb_ostream_t *stream, const uint8_t *buffer, size_t size)
498{
499    if (!pb_encode_varint(stream, (uint64_t)size))
500        return false;
501
502    return pb_write(stream, buffer, size);
503}
504
505bool checkreturn pb_encode_submessage(pb_ostream_t *stream, const pb_field_t fields[], const void *src_struct)
506{
507    /* First calculate the message size using a non-writing substream. */
508    pb_ostream_t substream = PB_OSTREAM_SIZING;
509    size_t size;
510    bool status;
511
512    if (!pb_encode(&substream, fields, src_struct))
513    {
514#ifndef PB_NO_ERRMSG
515        stream->errmsg = substream.errmsg;
516#endif
517        return false;
518    }
519
520    size = substream.bytes_written;
521
522    if (!pb_encode_varint(stream, (uint64_t)size))
523        return false;
524
525    if (stream->callback == NULL)
526        return pb_write(stream, NULL, size); /* Just sizing */
527
528    if (stream->bytes_written + size > stream->max_size)
529        PB_RETURN_ERROR(stream, "stream full");
530
531    /* Use a substream to verify that a callback doesn't write more than
532     * what it did the first time. */
533    substream.callback = stream->callback;
534    substream.state = stream->state;
535    substream.max_size = size;
536    substream.bytes_written = 0;
537#ifndef PB_NO_ERRMSG
538    substream.errmsg = NULL;
539#endif
540
541    status = pb_encode(&substream, fields, src_struct);
542
543    stream->bytes_written += substream.bytes_written;
544    stream->state = substream.state;
545#ifndef PB_NO_ERRMSG
546    stream->errmsg = substream.errmsg;
547#endif
548
549    if (substream.bytes_written != size)
550        PB_RETURN_ERROR(stream, "submsg size changed");
551
552    return status;
553}
554
555/* Field encoders */
556
557static bool checkreturn pb_enc_varint(pb_ostream_t *stream, const pb_field_t *field, const void *src)
558{
559    int64_t value = 0;
560
561    /* Cases 1 and 2 are for compilers that have smaller types for bool
562     * or enums. */
563    switch (field->data_size)
564    {
565        case 1: value = *(const int8_t*)src; break;
566        case 2: value = *(const int16_t*)src; break;
567        case 4: value = *(const int32_t*)src; break;
568        case 8: value = *(const int64_t*)src; break;
569        default: PB_RETURN_ERROR(stream, "invalid data_size");
570    }
571
572    return pb_encode_varint(stream, (uint64_t)value);
573}
574
575static bool checkreturn pb_enc_uvarint(pb_ostream_t *stream, const pb_field_t *field, const void *src)
576{
577    uint64_t value = 0;
578
579    switch (field->data_size)
580    {
581        case 4: value = *(const uint32_t*)src; break;
582        case 8: value = *(const uint64_t*)src; break;
583        default: PB_RETURN_ERROR(stream, "invalid data_size");
584    }
585
586    return pb_encode_varint(stream, value);
587}
588
589static bool checkreturn pb_enc_svarint(pb_ostream_t *stream, const pb_field_t *field, const void *src)
590{
591    int64_t value = 0;
592
593    switch (field->data_size)
594    {
595        case 4: value = *(const int32_t*)src; break;
596        case 8: value = *(const int64_t*)src; break;
597        default: PB_RETURN_ERROR(stream, "invalid data_size");
598    }
599
600    return pb_encode_svarint(stream, value);
601}
602
603static bool checkreturn pb_enc_fixed64(pb_ostream_t *stream, const pb_field_t *field, const void *src)
604{
605    UNUSED(field);
606    return pb_encode_fixed64(stream, src);
607}
608
609static bool checkreturn pb_enc_fixed32(pb_ostream_t *stream, const pb_field_t *field, const void *src)
610{
611    UNUSED(field);
612    return pb_encode_fixed32(stream, src);
613}
614
615static bool checkreturn pb_enc_bytes(pb_ostream_t *stream, const pb_field_t *field, const void *src)
616{
617    const pb_bytes_array_t *bytes = (const pb_bytes_array_t*)src;
618
619    if (src == NULL)
620    {
621        /* Threat null pointer as an empty bytes field */
622        return pb_encode_string(stream, NULL, 0);
623    }
624
625    if (PB_ATYPE(field->type) == PB_ATYPE_STATIC &&
626        PB_BYTES_ARRAY_T_ALLOCSIZE(bytes->size) > field->data_size)
627    {
628        PB_RETURN_ERROR(stream, "bytes size exceeded");
629    }
630
631    return pb_encode_string(stream, bytes->bytes, bytes->size);
632}
633
634static bool checkreturn pb_enc_string(pb_ostream_t *stream, const pb_field_t *field, const void *src)
635{
636    /* strnlen() is not always available, so just use a loop */
637    size_t size = 0;
638    size_t max_size = field->data_size;
639    const char *p = (const char*)src;
640
641    if (PB_ATYPE(field->type) == PB_ATYPE_POINTER)
642        max_size = (size_t)-1;
643
644    if (src == NULL)
645    {
646        size = 0; /* Threat null pointer as an empty string */
647    }
648    else
649    {
650        while (size < max_size && *p != '\0')
651        {
652            size++;
653            p++;
654        }
655    }
656
657    return pb_encode_string(stream, (const uint8_t*)src, size);
658}
659
660static bool checkreturn pb_enc_submessage(pb_ostream_t *stream, const pb_field_t *field, const void *src)
661{
662    if (field->ptr == NULL)
663        PB_RETURN_ERROR(stream, "invalid field descriptor");
664
665    return pb_encode_submessage(stream, (const pb_field_t*)field->ptr, src);
666}
667
668