1// Copyright 2015 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <brillo/streams/stream_utils.h>
6
7#include <limits>
8
9#include <base/bind.h>
10#include <brillo/message_loops/message_loop.h>
11#include <brillo/streams/stream_errors.h>
12
13namespace brillo {
14namespace stream_utils {
15
16namespace {
17
18// Status of asynchronous CopyData operation.
19struct CopyDataState {
20  brillo::StreamPtr in_stream;
21  brillo::StreamPtr out_stream;
22  std::vector<uint8_t> buffer;
23  uint64_t remaining_to_copy;
24  uint64_t size_copied;
25  CopyDataSuccessCallback success_callback;
26  CopyDataErrorCallback error_callback;
27};
28
29// Async CopyData I/O error callback.
30void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
31                     const brillo::Error* error) {
32  state->error_callback.Run(std::move(state->in_stream),
33                            std::move(state->out_stream), error);
34}
35
36// Forward declaration.
37void PerformRead(const std::shared_ptr<CopyDataState>& state);
38
39// Callback from read operation for CopyData. Writes the read data to the output
40// stream and invokes PerformRead when done to restart the copy cycle.
41void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
42  if (size == 0) {
43    state->success_callback.Run(std::move(state->in_stream),
44                                std::move(state->out_stream),
45                                state->size_copied);
46    return;
47  }
48  state->size_copied += size;
49  CHECK_GE(state->remaining_to_copy, size);
50  state->remaining_to_copy -= size;
51
52  brillo::ErrorPtr error;
53  bool success = state->out_stream->WriteAllAsync(
54      state->buffer.data(), size, base::Bind(&PerformRead, state),
55      base::Bind(&OnCopyDataError, state), &error);
56
57  if (!success)
58    OnCopyDataError(state, error.get());
59}
60
61// Performs the read part of asynchronous CopyData operation. Reads the data
62// from input stream and invokes PerformWrite when done to write the data to
63// the output stream.
64void PerformRead(const std::shared_ptr<CopyDataState>& state) {
65  brillo::ErrorPtr error;
66  const uint64_t buffer_size = state->buffer.size();
67  // |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
68  // also not overflow size_t, so the static_cast below is safe.
69  size_t size_to_read =
70      static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
71  if (size_to_read == 0)
72    return PerformWrite(state, 0);  // Nothing more to read. Finish operation.
73  bool success = state->in_stream->ReadAsync(
74      state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
75      base::Bind(OnCopyDataError, state), &error);
76
77  if (!success)
78    OnCopyDataError(state, error.get());
79}
80
81}  // anonymous namespace
82
83bool ErrorStreamClosed(const tracked_objects::Location& location,
84                       ErrorPtr* error) {
85  Error::AddTo(error,
86               location,
87               errors::stream::kDomain,
88               errors::stream::kStreamClosed,
89               "Stream is closed");
90  return false;
91}
92
93bool ErrorOperationNotSupported(const tracked_objects::Location& location,
94                                ErrorPtr* error) {
95  Error::AddTo(error,
96               location,
97               errors::stream::kDomain,
98               errors::stream::kOperationNotSupported,
99               "Stream operation not supported");
100  return false;
101}
102
103bool ErrorReadPastEndOfStream(const tracked_objects::Location& location,
104                              ErrorPtr* error) {
105  Error::AddTo(error,
106               location,
107               errors::stream::kDomain,
108               errors::stream::kPartialData,
109               "Reading past the end of stream");
110  return false;
111}
112
113bool ErrorOperationTimeout(const tracked_objects::Location& location,
114                           ErrorPtr* error) {
115  Error::AddTo(error,
116               location,
117               errors::stream::kDomain,
118               errors::stream::kTimeout,
119               "Operation timed out");
120  return false;
121}
122
123bool CheckInt64Overflow(const tracked_objects::Location& location,
124                        uint64_t position,
125                        int64_t offset,
126                        ErrorPtr* error) {
127  if (offset < 0) {
128    // Subtracting the offset. Make sure we do not underflow.
129    uint64_t unsigned_offset = static_cast<uint64_t>(-offset);
130    if (position >= unsigned_offset)
131      return true;
132  } else {
133    // Adding the offset. Make sure we do not overflow unsigned 64 bits first.
134    if (position <= std::numeric_limits<uint64_t>::max() - offset) {
135      // We definitely will not overflow the unsigned 64 bit integer.
136      // Now check that we end up within the limits of signed 64 bit integer.
137      uint64_t new_position = position + offset;
138      uint64_t max = std::numeric_limits<int64_t>::max();
139      if (new_position <= max)
140        return true;
141    }
142  }
143  Error::AddTo(error,
144               location,
145               errors::stream::kDomain,
146               errors::stream::kInvalidParameter,
147               "The stream offset value is out of range");
148  return false;
149}
150
151bool CalculateStreamPosition(const tracked_objects::Location& location,
152                             int64_t offset,
153                             Stream::Whence whence,
154                             uint64_t current_position,
155                             uint64_t stream_size,
156                             uint64_t* new_position,
157                             ErrorPtr* error) {
158  uint64_t pos = 0;
159  switch (whence) {
160    case Stream::Whence::FROM_BEGIN:
161      pos = 0;
162      break;
163
164    case Stream::Whence::FROM_CURRENT:
165      pos = current_position;
166      break;
167
168    case Stream::Whence::FROM_END:
169      pos = stream_size;
170      break;
171
172    default:
173      Error::AddTo(error,
174                   location,
175                   errors::stream::kDomain,
176                   errors::stream::kInvalidParameter,
177                   "Invalid stream position whence");
178      return false;
179  }
180
181  if (!CheckInt64Overflow(location, pos, offset, error))
182    return false;
183
184  *new_position = static_cast<uint64_t>(pos + offset);
185  return true;
186}
187
188void CopyData(StreamPtr in_stream,
189              StreamPtr out_stream,
190              const CopyDataSuccessCallback& success_callback,
191              const CopyDataErrorCallback& error_callback) {
192  CopyData(std::move(in_stream), std::move(out_stream),
193           std::numeric_limits<uint64_t>::max(), 4096, success_callback,
194           error_callback);
195}
196
197void CopyData(StreamPtr in_stream,
198              StreamPtr out_stream,
199              uint64_t max_size_to_copy,
200              size_t buffer_size,
201              const CopyDataSuccessCallback& success_callback,
202              const CopyDataErrorCallback& error_callback) {
203  auto state = std::make_shared<CopyDataState>();
204  state->in_stream = std::move(in_stream);
205  state->out_stream = std::move(out_stream);
206  state->buffer.resize(buffer_size);
207  state->remaining_to_copy = max_size_to_copy;
208  state->size_copied = 0;
209  state->success_callback = success_callback;
210  state->error_callback = error_callback;
211  brillo::MessageLoop::current()->PostTask(FROM_HERE,
212                                             base::Bind(&PerformRead, state));
213}
214
215}  // namespace stream_utils
216}  // namespace brillo
217