1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "mojo/system/message_in_transit.h"
6
7#include <string.h>
8
9#include "base/compiler_specific.h"
10#include "base/logging.h"
11#include "mojo/system/constants.h"
12#include "mojo/system/transport_data.h"
13
14namespace mojo {
15namespace system {
16
17STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::Type
18    MessageInTransit::kTypeMessagePipeEndpoint;
19STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::Type
20    MessageInTransit::kTypeMessagePipe;
21STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::Type
22    MessageInTransit::kTypeChannel;
23STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::Type
24    MessageInTransit::kTypeRawChannel;
25STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::Subtype
26    MessageInTransit::kSubtypeMessagePipeEndpointData;
27STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::Subtype
28    MessageInTransit::kSubtypeChannelRunMessagePipeEndpoint;
29STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::Subtype
30    MessageInTransit::kSubtypeChannelRemoveMessagePipeEndpoint;
31STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::Subtype
32    MessageInTransit::kSubtypeChannelRemoveMessagePipeEndpointAck;
33STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::Subtype
34    MessageInTransit::kSubtypeRawChannelPosixExtraPlatformHandles;
35STATIC_CONST_MEMBER_DEFINITION const MessageInTransit::EndpointId
36    MessageInTransit::kInvalidEndpointId;
37STATIC_CONST_MEMBER_DEFINITION const size_t MessageInTransit::kMessageAlignment;
38
39struct MessageInTransit::PrivateStructForCompileAsserts {
40  // The size of |Header| must be a multiple of the alignment.
41  static_assert(sizeof(Header) % kMessageAlignment == 0,
42                "sizeof(MessageInTransit::Header) invalid");
43  // Avoid dangerous situations, but making sure that the size of the "header" +
44  // the size of the data fits into a 31-bit number.
45  static_assert(static_cast<uint64_t>(sizeof(Header)) + kMaxMessageNumBytes <=
46                    0x7fffffffULL,
47                "kMaxMessageNumBytes too big");
48
49  // We assume (to avoid extra rounding code) that the maximum message (data)
50  // size is a multiple of the alignment.
51  static_assert(kMaxMessageNumBytes % kMessageAlignment == 0,
52                "kMessageAlignment not a multiple of alignment");
53};
54
55MessageInTransit::View::View(size_t message_size, const void* buffer)
56    : buffer_(buffer) {
57  size_t next_message_size = 0;
58  DCHECK(MessageInTransit::GetNextMessageSize(
59      buffer_, message_size, &next_message_size));
60  DCHECK_EQ(message_size, next_message_size);
61  // This should be equivalent.
62  DCHECK_EQ(message_size, total_size());
63}
64
65bool MessageInTransit::View::IsValid(size_t serialized_platform_handle_size,
66                                     const char** error_message) const {
67  // Note: This also implies a check on the |main_buffer_size()|, which is just
68  // |RoundUpMessageAlignment(sizeof(Header) + num_bytes())|.
69  if (num_bytes() > kMaxMessageNumBytes) {
70    *error_message = "Message data payload too large";
71    return false;
72  }
73
74  if (transport_data_buffer_size() > 0) {
75    const char* e =
76        TransportData::ValidateBuffer(serialized_platform_handle_size,
77                                      transport_data_buffer(),
78                                      transport_data_buffer_size());
79    if (e) {
80      *error_message = e;
81      return false;
82    }
83  }
84
85  return true;
86}
87
88MessageInTransit::MessageInTransit(Type type,
89                                   Subtype subtype,
90                                   uint32_t num_bytes,
91                                   const void* bytes)
92    : main_buffer_size_(RoundUpMessageAlignment(sizeof(Header) + num_bytes)),
93      main_buffer_(static_cast<char*>(
94          base::AlignedAlloc(main_buffer_size_, kMessageAlignment))) {
95  ConstructorHelper(type, subtype, num_bytes);
96  if (bytes) {
97    memcpy(MessageInTransit::bytes(), bytes, num_bytes);
98    memset(static_cast<char*>(MessageInTransit::bytes()) + num_bytes,
99           0,
100           main_buffer_size_ - sizeof(Header) - num_bytes);
101  } else {
102    memset(MessageInTransit::bytes(), 0, main_buffer_size_ - sizeof(Header));
103  }
104}
105
106MessageInTransit::MessageInTransit(Type type,
107                                   Subtype subtype,
108                                   uint32_t num_bytes,
109                                   UserPointer<const void> bytes)
110    : main_buffer_size_(RoundUpMessageAlignment(sizeof(Header) + num_bytes)),
111      main_buffer_(static_cast<char*>(
112          base::AlignedAlloc(main_buffer_size_, kMessageAlignment))) {
113  ConstructorHelper(type, subtype, num_bytes);
114  bytes.GetArray(MessageInTransit::bytes(), num_bytes);
115}
116
117MessageInTransit::MessageInTransit(const View& message_view)
118    : main_buffer_size_(message_view.main_buffer_size()),
119      main_buffer_(static_cast<char*>(
120          base::AlignedAlloc(main_buffer_size_, kMessageAlignment))) {
121  DCHECK_GE(main_buffer_size_, sizeof(Header));
122  DCHECK_EQ(main_buffer_size_ % kMessageAlignment, 0u);
123
124  memcpy(main_buffer_.get(), message_view.main_buffer(), main_buffer_size_);
125  DCHECK_EQ(main_buffer_size_,
126            RoundUpMessageAlignment(sizeof(Header) + num_bytes()));
127}
128
129MessageInTransit::~MessageInTransit() {
130  if (dispatchers_) {
131    for (size_t i = 0; i < dispatchers_->size(); i++) {
132      if (!(*dispatchers_)[i].get())
133        continue;
134
135      DCHECK((*dispatchers_)[i]->HasOneRef());
136      (*dispatchers_)[i]->Close();
137    }
138  }
139}
140
141// static
142bool MessageInTransit::GetNextMessageSize(const void* buffer,
143                                          size_t buffer_size,
144                                          size_t* next_message_size) {
145  DCHECK(next_message_size);
146  if (!buffer_size)
147    return false;
148  DCHECK(buffer);
149  DCHECK_EQ(
150      reinterpret_cast<uintptr_t>(buffer) % MessageInTransit::kMessageAlignment,
151      0u);
152
153  if (buffer_size < sizeof(Header))
154    return false;
155
156  const Header* header = static_cast<const Header*>(buffer);
157  *next_message_size = header->total_size;
158  DCHECK_EQ(*next_message_size % kMessageAlignment, 0u);
159  return true;
160}
161
162void MessageInTransit::SetDispatchers(
163    scoped_ptr<DispatcherVector> dispatchers) {
164  DCHECK(dispatchers);
165  DCHECK(!dispatchers_);
166  DCHECK(!transport_data_);
167
168  dispatchers_ = dispatchers.Pass();
169#ifndef NDEBUG
170  for (size_t i = 0; i < dispatchers_->size(); i++)
171    DCHECK(!(*dispatchers_)[i].get() || (*dispatchers_)[i]->HasOneRef());
172#endif
173}
174
175void MessageInTransit::SetTransportData(
176    scoped_ptr<TransportData> transport_data) {
177  DCHECK(transport_data);
178  DCHECK(!transport_data_);
179  DCHECK(!dispatchers_);
180
181  transport_data_ = transport_data.Pass();
182}
183
184void MessageInTransit::SerializeAndCloseDispatchers(Channel* channel) {
185  DCHECK(channel);
186  DCHECK(!transport_data_);
187
188  if (!dispatchers_ || !dispatchers_->size())
189    return;
190
191  transport_data_.reset(new TransportData(dispatchers_.Pass(), channel));
192
193  // Update the sizes in the message header.
194  UpdateTotalSize();
195}
196
197void MessageInTransit::ConstructorHelper(Type type,
198                                         Subtype subtype,
199                                         uint32_t num_bytes) {
200  DCHECK_LE(num_bytes, kMaxMessageNumBytes);
201
202  // |total_size| is updated below, from the other values.
203  header()->type = type;
204  header()->subtype = subtype;
205  header()->source_id = kInvalidEndpointId;
206  header()->destination_id = kInvalidEndpointId;
207  header()->num_bytes = num_bytes;
208  header()->unused = 0;
209  // Note: If dispatchers are subsequently attached, then |total_size| will have
210  // to be adjusted.
211  UpdateTotalSize();
212}
213
214void MessageInTransit::UpdateTotalSize() {
215  DCHECK_EQ(main_buffer_size_ % kMessageAlignment, 0u);
216  header()->total_size = static_cast<uint32_t>(main_buffer_size_);
217  if (transport_data_) {
218    header()->total_size +=
219        static_cast<uint32_t>(transport_data_->buffer_size());
220  }
221}
222
223}  // namespace system
224}  // namespace mojo
225