1/*
2 * Copyright (C) 2012 Google Inc.  All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions are
6 * met:
7 *
8 *     * Redistributions of source code must retain the above copyright
9 * notice, this list of conditions and the following disclaimer.
10 *     * Redistributions in binary form must reproduce the above
11 * copyright notice, this list of conditions and the following disclaimer
12 * in the documentation and/or other materials provided with the
13 * distribution.
14 *     * Neither the name of Google Inc. nor the names of its
15 * contributors may be used to endorse or promote products derived from
16 * this software without specific prior written permission.
17 *
18 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 */
30
31#include "config.h"
32#include "modules/websockets/WebSocketDeflater.h"
33
34#include "platform/Logging.h"
35#include "wtf/FastMalloc.h"
36#include "wtf/HashMap.h"
37#include "wtf/StdLibExtras.h"
38#include "wtf/StringExtras.h"
39#include "wtf/text/StringHash.h"
40#include "wtf/text/WTFString.h"
41#include <zlib.h>
42
43namespace blink {
44
// zlib memLevel passed to deflateInit2(). 1 is the smallest value zlib accepts:
// minimum memory, at the cost of speed and compression ratio.
static const int defaultMemLevel = 1;
// Number of bytes by which an output buffer is grown each time zlib is asked
// for more output.
static const size_t bufferIncrementUnit = 4096;
47
48PassOwnPtr<WebSocketDeflater> WebSocketDeflater::create(int windowBits, ContextTakeOverMode contextTakeOverMode)
49{
50    return adoptPtr(new WebSocketDeflater(windowBits, contextTakeOverMode));
51}
52
53WebSocketDeflater::WebSocketDeflater(int windowBits, ContextTakeOverMode contextTakeOverMode)
54    : m_windowBits(windowBits)
55    , m_contextTakeOverMode(contextTakeOverMode)
56    , m_isBytesAdded(false)
57{
58    ASSERT(m_windowBits >= 8);
59    ASSERT(m_windowBits <= 15);
60    m_stream = adoptPtr(new z_stream);
61    memset(m_stream.get(), 0, sizeof(z_stream));
62}
63
64bool WebSocketDeflater::initialize()
65{
66    return deflateInit2(m_stream.get(), Z_DEFAULT_COMPRESSION, Z_DEFLATED, -m_windowBits, defaultMemLevel, Z_DEFAULT_STRATEGY) == Z_OK;
67}
68
69WebSocketDeflater::~WebSocketDeflater()
70{
71    int result = deflateEnd(m_stream.get());
72    if (result != Z_OK)
73        WTF_LOG(Network, "WebSocketDeflater %p Destructor deflateEnd() failed: %d is returned", this, result);
74}
75
76static void setStreamParameter(z_stream* stream, const char* inputData, size_t inputLength, char* outputData, size_t outputLength)
77{
78    stream->next_in = reinterpret_cast<Bytef*>(const_cast<char*>(inputData));
79    stream->avail_in = inputLength;
80    stream->next_out = reinterpret_cast<Bytef*>(outputData);
81    stream->avail_out = outputLength;
82}
83
84bool WebSocketDeflater::addBytes(const char* data, size_t length)
85{
86    if (!length)
87        return false;
88
89    // The estimation by deflateBound is not accurate if the zlib has some remaining input of the last compression.
90    size_t maxLength = deflateBound(m_stream.get(), length);
91    do {
92        size_t writePosition = m_buffer.size();
93        m_buffer.grow(writePosition + maxLength);
94        setStreamParameter(m_stream.get(), data, length, m_buffer.data() + writePosition, maxLength);
95        int result = deflate(m_stream.get(), Z_NO_FLUSH);
96        if (result != Z_OK)
97            return false;
98        m_buffer.shrink(writePosition + maxLength - m_stream->avail_out);
99        maxLength *= 2;
100    } while (m_stream->avail_in > 0);
101    m_isBytesAdded = true;
102    return true;
103}
104
105bool WebSocketDeflater::finish()
106{
107    if (!m_isBytesAdded) {
108        // Since consecutive calls of deflate with Z_SYNC_FLUSH and no input lead to an error,
109        // we create and return the output for the empty input manually.
110        ASSERT(!m_buffer.size());
111        m_buffer.append("\x00", 1);
112        return true;
113    }
114    while (true) {
115        size_t writePosition = m_buffer.size();
116        m_buffer.grow(writePosition + bufferIncrementUnit);
117        size_t availableCapacity = m_buffer.size() - writePosition;
118        setStreamParameter(m_stream.get(), 0, 0, m_buffer.data() + writePosition, availableCapacity);
119        int result = deflate(m_stream.get(), Z_SYNC_FLUSH);
120        m_buffer.shrink(writePosition + availableCapacity - m_stream->avail_out);
121        if (result == Z_OK)
122            break;
123        if (result != Z_BUF_ERROR)
124            return false;
125    }
126    // Remove 4 octets from the tail as the specification requires.
127    if (m_buffer.size() <= 4)
128        return false;
129    m_buffer.resize(m_buffer.size() - 4);
130    m_isBytesAdded = false;
131    return true;
132}
133
134void WebSocketDeflater::reset()
135{
136    m_buffer.clear();
137    m_isBytesAdded = false;
138    if (m_contextTakeOverMode == DoNotTakeOverContext)
139        deflateReset(m_stream.get());
140}
141
142void WebSocketDeflater::softReset()
143{
144    m_buffer.clear();
145}
146
147PassOwnPtr<WebSocketInflater> WebSocketInflater::create(int windowBits)
148{
149    return adoptPtr(new WebSocketInflater(windowBits));
150}
151
152WebSocketInflater::WebSocketInflater(int windowBits)
153    : m_windowBits(windowBits)
154{
155    m_stream = adoptPtr(new z_stream);
156    memset(m_stream.get(), 0, sizeof(z_stream));
157}
158
159bool WebSocketInflater::initialize()
160{
161    return inflateInit2(m_stream.get(), -m_windowBits) == Z_OK;
162}
163
164WebSocketInflater::~WebSocketInflater()
165{
166    int result = inflateEnd(m_stream.get());
167    if (result != Z_OK)
168        WTF_LOG(Network, "WebSocketInflater %p Destructor inflateEnd() failed: %d is returned", this, result);
169}
170
171bool WebSocketInflater::addBytes(const char* data, size_t length)
172{
173    if (!length)
174        return false;
175
176    size_t consumedSoFar = 0;
177    while (consumedSoFar < length) {
178        size_t writePosition = m_buffer.size();
179        m_buffer.grow(writePosition + bufferIncrementUnit);
180        size_t availableCapacity = m_buffer.size() - writePosition;
181        size_t remainingLength = length - consumedSoFar;
182        setStreamParameter(m_stream.get(), data + consumedSoFar, remainingLength, m_buffer.data() + writePosition, availableCapacity);
183        int result = inflate(m_stream.get(), Z_NO_FLUSH);
184        consumedSoFar += remainingLength - m_stream->avail_in;
185        m_buffer.shrink(writePosition + availableCapacity - m_stream->avail_out);
186        if (result == Z_BUF_ERROR)
187            continue;
188        if (result == Z_STREAM_END) {
189            // Received a block with BFINAL set to 1. Reset decompression state.
190            if (inflateReset(m_stream.get()) != Z_OK)
191                return false;
192            continue;
193        }
194        if (result != Z_OK)
195            return false;
196        ASSERT(remainingLength > m_stream->avail_in);
197    }
198    ASSERT(consumedSoFar == length);
199    return true;
200}
201
202bool WebSocketInflater::finish()
203{
204    static const char strippedFields[] = "\0\0\xff\xff";
205    static const size_t strippedLength = 4;
206
207    // Appends 4 octests of 0x00 0x00 0xff 0xff
208    size_t consumedSoFar = 0;
209    while (consumedSoFar < strippedLength) {
210        size_t writePosition = m_buffer.size();
211        m_buffer.grow(writePosition + bufferIncrementUnit);
212        size_t availableCapacity = m_buffer.size() - writePosition;
213        size_t remainingLength = strippedLength - consumedSoFar;
214        setStreamParameter(m_stream.get(), strippedFields + consumedSoFar, remainingLength, m_buffer.data() + writePosition, availableCapacity);
215        int result = inflate(m_stream.get(), Z_FINISH);
216        consumedSoFar += remainingLength - m_stream->avail_in;
217        m_buffer.shrink(writePosition + availableCapacity - m_stream->avail_out);
218        if (result == Z_BUF_ERROR)
219            continue;
220        if (result != Z_OK && result != Z_STREAM_END)
221            return false;
222        ASSERT(remainingLength > m_stream->avail_in);
223    }
224    ASSERT(consumedSoFar == strippedLength);
225
226    return true;
227}
228
229void WebSocketInflater::reset()
230{
231    m_buffer.clear();
232}
233
234} // namespace blink
235
236