1// Copyright 2015 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <brillo/streams/stream.h>
6
7#include <algorithm>
8
9#include <base/bind.h>
10#include <brillo/message_loops/message_loop.h>
11#include <brillo/pointer_utils.h>
12#include <brillo/streams/stream_errors.h>
13#include <brillo/streams/stream_utils.h>
14
15namespace brillo {
16
17bool Stream::TruncateBlocking(ErrorPtr* error) {
18  return SetSizeBlocking(GetPosition(), error);
19}
20
21bool Stream::SetPosition(uint64_t position, ErrorPtr* error) {
22  if (!stream_utils::CheckInt64Overflow(FROM_HERE, position, 0, error))
23    return false;
24  return Seek(position, Whence::FROM_BEGIN, nullptr, error);
25}
26
27bool Stream::ReadAsync(void* buffer,
28                       size_t size_to_read,
29                       const base::Callback<void(size_t)>& success_callback,
30                       const ErrorCallback& error_callback,
31                       ErrorPtr* error) {
32  if (is_async_read_pending_) {
33    Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
34                 errors::stream::kOperationNotSupported,
35                 "Another asynchronous operation is still pending");
36    return false;
37  }
38
39  auto callback = base::Bind(&Stream::IgnoreEOSCallback, success_callback);
40  // If we can read some data right away non-blocking we should still run the
41  // callback from the main loop, so we pass true here for force_async_callback.
42  return ReadAsyncImpl(buffer, size_to_read, callback, error_callback, error,
43                       true);
44}
45
46bool Stream::ReadAllAsync(void* buffer,
47                          size_t size_to_read,
48                          const base::Closure& success_callback,
49                          const ErrorCallback& error_callback,
50                          ErrorPtr* error) {
51  if (is_async_read_pending_) {
52    Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
53                 errors::stream::kOperationNotSupported,
54                 "Another asynchronous operation is still pending");
55    return false;
56  }
57
58  auto callback = base::Bind(&Stream::ReadAllAsyncCallback,
59                             weak_ptr_factory_.GetWeakPtr(), buffer,
60                             size_to_read, success_callback, error_callback);
61  return ReadAsyncImpl(buffer, size_to_read, callback, error_callback, error,
62                       true);
63}
64
65bool Stream::ReadBlocking(void* buffer,
66                          size_t size_to_read,
67                          size_t* size_read,
68                          ErrorPtr* error) {
69  for (;;) {
70    bool eos = false;
71    if (!ReadNonBlocking(buffer, size_to_read, size_read, &eos, error))
72      return false;
73
74    if (*size_read > 0 || eos)
75      break;
76
77    if (!WaitForDataBlocking(AccessMode::READ, base::TimeDelta::Max(), nullptr,
78                             error)) {
79      return false;
80    }
81  }
82  return true;
83}
84
85bool Stream::ReadAllBlocking(void* buffer,
86                             size_t size_to_read,
87                             ErrorPtr* error) {
88  while (size_to_read > 0) {
89    size_t size_read = 0;
90    if (!ReadBlocking(buffer, size_to_read, &size_read, error))
91      return false;
92
93    if (size_read == 0)
94      return stream_utils::ErrorReadPastEndOfStream(FROM_HERE, error);
95
96    size_to_read -= size_read;
97    buffer = AdvancePointer(buffer, size_read);
98  }
99  return true;
100}
101
102bool Stream::WriteAsync(const void* buffer,
103                        size_t size_to_write,
104                        const base::Callback<void(size_t)>& success_callback,
105                        const ErrorCallback& error_callback,
106                        ErrorPtr* error) {
107  if (is_async_write_pending_) {
108    Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
109                 errors::stream::kOperationNotSupported,
110                 "Another asynchronous operation is still pending");
111    return false;
112  }
113  // If we can read some data right away non-blocking we should still run the
114  // callback from the main loop, so we pass true here for force_async_callback.
115  return WriteAsyncImpl(buffer, size_to_write, success_callback, error_callback,
116                        error, true);
117}
118
119bool Stream::WriteAllAsync(const void* buffer,
120                           size_t size_to_write,
121                           const base::Closure& success_callback,
122                           const ErrorCallback& error_callback,
123                           ErrorPtr* error) {
124  if (is_async_write_pending_) {
125    Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
126                 errors::stream::kOperationNotSupported,
127                 "Another asynchronous operation is still pending");
128    return false;
129  }
130
131  auto callback = base::Bind(&Stream::WriteAllAsyncCallback,
132                             weak_ptr_factory_.GetWeakPtr(), buffer,
133                             size_to_write, success_callback, error_callback);
134  return WriteAsyncImpl(buffer, size_to_write, callback, error_callback, error,
135                        true);
136}
137
138bool Stream::WriteBlocking(const void* buffer,
139                           size_t size_to_write,
140                           size_t* size_written,
141                           ErrorPtr* error) {
142  for (;;) {
143    if (!WriteNonBlocking(buffer, size_to_write, size_written, error))
144      return false;
145
146    if (*size_written > 0 || size_to_write == 0)
147      break;
148
149    if (!WaitForDataBlocking(AccessMode::WRITE, base::TimeDelta::Max(), nullptr,
150                             error)) {
151      return false;
152    }
153  }
154  return true;
155}
156
157bool Stream::WriteAllBlocking(const void* buffer,
158                              size_t size_to_write,
159                              ErrorPtr* error) {
160  while (size_to_write > 0) {
161    size_t size_written = 0;
162    if (!WriteBlocking(buffer, size_to_write, &size_written, error))
163      return false;
164
165    if (size_written == 0) {
166      Error::AddTo(error, FROM_HERE, errors::stream::kDomain,
167                   errors::stream::kPartialData,
168                   "Failed to write all the data");
169      return false;
170    }
171    size_to_write -= size_written;
172    buffer = AdvancePointer(buffer, size_written);
173  }
174  return true;
175}
176
177bool Stream::FlushAsync(const base::Closure& success_callback,
178                        const ErrorCallback& error_callback,
179                        ErrorPtr* /* error */) {
180  auto callback = base::Bind(&Stream::FlushAsyncCallback,
181                             weak_ptr_factory_.GetWeakPtr(),
182                             success_callback, error_callback);
183  MessageLoop::current()->PostTask(FROM_HERE, callback);
184  return true;
185}
186
187void Stream::IgnoreEOSCallback(
188    const base::Callback<void(size_t)>& success_callback,
189    size_t bytes,
190    bool /* eos */) {
191  success_callback.Run(bytes);
192}
193
194bool Stream::ReadAsyncImpl(
195    void* buffer,
196    size_t size_to_read,
197    const base::Callback<void(size_t, bool)>& success_callback,
198    const ErrorCallback& error_callback,
199    ErrorPtr* error,
200    bool force_async_callback) {
201  CHECK(!is_async_read_pending_);
202  // We set this value to true early in the function so calling others will
203  // prevent us from calling WaitForData() to make calls to
204  // ReadAsync() fail while we run WaitForData().
205  is_async_read_pending_ = true;
206
207  size_t read = 0;
208  bool eos = false;
209  if (!ReadNonBlocking(buffer, size_to_read, &read, &eos, error))
210    return false;
211
212  if (read > 0 || eos) {
213    if (force_async_callback) {
214      MessageLoop::current()->PostTask(
215          FROM_HERE,
216          base::Bind(&Stream::OnReadAsyncDone, weak_ptr_factory_.GetWeakPtr(),
217                     success_callback, read, eos));
218    } else {
219      is_async_read_pending_ = false;
220      success_callback.Run(read, eos);
221    }
222    return true;
223  }
224
225  is_async_read_pending_ = WaitForData(
226      AccessMode::READ,
227      base::Bind(&Stream::OnReadAvailable, weak_ptr_factory_.GetWeakPtr(),
228                 buffer, size_to_read, success_callback, error_callback),
229      error);
230  return is_async_read_pending_;
231}
232
233void Stream::OnReadAsyncDone(
234    const base::Callback<void(size_t, bool)>& success_callback,
235    size_t bytes_read,
236    bool eos) {
237  is_async_read_pending_ = false;
238  success_callback.Run(bytes_read, eos);
239}
240
241void Stream::OnReadAvailable(
242    void* buffer,
243    size_t size_to_read,
244    const base::Callback<void(size_t, bool)>& success_callback,
245    const ErrorCallback& error_callback,
246    AccessMode mode) {
247  CHECK(stream_utils::IsReadAccessMode(mode));
248  CHECK(is_async_read_pending_);
249  is_async_read_pending_ = false;
250  ErrorPtr error;
251  // Just reschedule the read operation but don't need to run the callback from
252  // the main loop since we are already running on a callback.
253  if (!ReadAsyncImpl(buffer, size_to_read, success_callback, error_callback,
254                     &error, false)) {
255    error_callback.Run(error.get());
256  }
257}
258
259bool Stream::WriteAsyncImpl(
260    const void* buffer,
261    size_t size_to_write,
262    const base::Callback<void(size_t)>& success_callback,
263    const ErrorCallback& error_callback,
264    ErrorPtr* error,
265    bool force_async_callback) {
266  CHECK(!is_async_write_pending_);
267  // We set this value to true early in the function so calling others will
268  // prevent us from calling WaitForData() to make calls to
269  // ReadAsync() fail while we run WaitForData().
270  is_async_write_pending_ = true;
271
272  size_t written = 0;
273  if (!WriteNonBlocking(buffer, size_to_write, &written, error))
274    return false;
275
276  if (written > 0) {
277    if (force_async_callback) {
278      MessageLoop::current()->PostTask(
279          FROM_HERE,
280          base::Bind(&Stream::OnWriteAsyncDone, weak_ptr_factory_.GetWeakPtr(),
281                     success_callback, written));
282    } else {
283      is_async_write_pending_ = false;
284      success_callback.Run(written);
285    }
286    return true;
287  }
288  is_async_write_pending_ = WaitForData(
289      AccessMode::WRITE,
290      base::Bind(&Stream::OnWriteAvailable, weak_ptr_factory_.GetWeakPtr(),
291                 buffer, size_to_write, success_callback, error_callback),
292      error);
293  return is_async_write_pending_;
294}
295
296void Stream::OnWriteAsyncDone(
297    const base::Callback<void(size_t)>& success_callback,
298    size_t size_written) {
299  is_async_write_pending_ = false;
300  success_callback.Run(size_written);
301}
302
303void Stream::OnWriteAvailable(
304    const void* buffer,
305    size_t size,
306    const base::Callback<void(size_t)>& success_callback,
307    const ErrorCallback& error_callback,
308    AccessMode mode) {
309  CHECK(stream_utils::IsWriteAccessMode(mode));
310  CHECK(is_async_write_pending_);
311  is_async_write_pending_ = false;
312  ErrorPtr error;
313  // Just reschedule the read operation but don't need to run the callback from
314  // the main loop since we are already running on a callback.
315  if (!WriteAsyncImpl(buffer, size, success_callback, error_callback, &error,
316                      false)) {
317    error_callback.Run(error.get());
318  }
319}
320
321void Stream::ReadAllAsyncCallback(void* buffer,
322                                  size_t size_to_read,
323                                  const base::Closure& success_callback,
324                                  const ErrorCallback& error_callback,
325                                  size_t size_read,
326                                  bool eos) {
327  ErrorPtr error;
328  size_to_read -= size_read;
329  if (size_to_read != 0 && eos) {
330    stream_utils::ErrorReadPastEndOfStream(FROM_HERE, &error);
331    error_callback.Run(error.get());
332    return;
333  }
334
335  if (size_to_read) {
336    buffer = AdvancePointer(buffer, size_read);
337    auto callback = base::Bind(&Stream::ReadAllAsyncCallback,
338                               weak_ptr_factory_.GetWeakPtr(), buffer,
339                               size_to_read, success_callback, error_callback);
340    if (!ReadAsyncImpl(buffer, size_to_read, callback, error_callback, &error,
341                       false)) {
342      error_callback.Run(error.get());
343    }
344  } else {
345    success_callback.Run();
346  }
347}
348
349void Stream::WriteAllAsyncCallback(const void* buffer,
350                                   size_t size_to_write,
351                                   const base::Closure& success_callback,
352                                   const ErrorCallback& error_callback,
353                                   size_t size_written) {
354  ErrorPtr error;
355  if (size_to_write != 0 && size_written == 0) {
356    Error::AddTo(&error, FROM_HERE, errors::stream::kDomain,
357                 errors::stream::kPartialData, "Failed to write all the data");
358    error_callback.Run(error.get());
359    return;
360  }
361  size_to_write -= size_written;
362  if (size_to_write) {
363    buffer = AdvancePointer(buffer, size_written);
364    auto callback = base::Bind(&Stream::WriteAllAsyncCallback,
365                               weak_ptr_factory_.GetWeakPtr(), buffer,
366                               size_to_write, success_callback, error_callback);
367    if (!WriteAsyncImpl(buffer, size_to_write, callback, error_callback, &error,
368                        false)) {
369      error_callback.Run(error.get());
370    }
371  } else {
372    success_callback.Run();
373  }
374}
375
376void Stream::FlushAsyncCallback(const base::Closure& success_callback,
377                                const ErrorCallback& error_callback) {
378  ErrorPtr error;
379  if (FlushBlocking(&error)) {
380    success_callback.Run();
381  } else {
382    error_callback.Run(error.get());
383  }
384}
385
386void Stream::CancelPendingAsyncOperations() {
387  weak_ptr_factory_.InvalidateWeakPtrs();
388  is_async_read_pending_ = false;
389  is_async_write_pending_ = false;
390}
391
392}  // namespace brillo
393