1#!/usr/bin/python
2# Copyright 2014 The Chromium Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6# pylint: disable=W0104,W0106,F0401,R0201
7
8import optparse
9import os.path
10import sys
11
12import interface
13
14
15def _ScriptDir():
16  return os.path.dirname(os.path.abspath(__file__))
17
18
19def _GetDirAbove(dirname):
20  """Returns the directory "above" this file containing |dirname| (which must
21  also be "above" this file)."""
22  path = _ScriptDir()
23  while True:
24    path, tail = os.path.split(path)
25    assert tail
26    if tail == dirname:
27      return path
28
29
def _AddThirdPartyImportPath():
  """Appends <parent of the 'mojo' directory>/third_party to sys.path."""
  third_party_dir = os.path.join(_GetDirAbove('mojo'), 'third_party')
  sys.path.append(third_party_dir)
32
33
# The third_party directory must be on sys.path before jinja2 is imported;
# presumably jinja2 is vendored there rather than installed (TODO confirm).
_AddThirdPartyImportPath()
import jinja2

# Templates (e.g. 'libmojo.cc.tmpl') are loaded from this script's directory.
# keep_trailing_newline stops jinja2 from stripping a template's final newline.
loader = jinja2.FileSystemLoader(_ScriptDir())
jinja_env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
39
40
# Accumulate lines of code with varying levels of indentation.
class CodeWriter(object):
  """Collects lines of generated code, each prefixed by the current margin.

  Lines are appended with `writer << line`.  PushMargin()/PopMargin() (or the
  Indent() context manager) adjust the margin by two spaces per level.
  """

  def __init__(self):
    self._lines = []
    self._margin = ''
    self._margin_stack = []

  def __lshift__(self, line):
    # Trailing whitespace is dropped so that blank lines stay truly empty.
    indented = self._margin + line
    self._lines.append(indented.rstrip())

  def PushMargin(self):
    saved = self._margin
    self._margin_stack.append(saved)
    self._margin = saved + '  '

  def PopMargin(self):
    restored = self._margin_stack.pop()
    self._margin = restored

  def GetValue(self):
    """Returns all lines joined, without trailing blank lines, plus one '\\n'."""
    text = '\n'.join(self._lines)
    return text.rstrip() + '\n'

  def Indent(self):
    return Indent(self)
63
64
# Context handler that automatically indents and dedents a CodeWriter
class Indent(object):
  """Context manager: pushes |writer|'s margin on entry, pops it on exit."""

  def __init__(self, writer):
    self._writer = writer

  def __enter__(self):
    writer = self._writer
    writer.PushMargin()

  def __exit__(self, type_, value, traceback):
    # Runs even if the body raised.  Returning None (falsy) lets any
    # exception propagate.
    writer = self._writer
    writer.PopMargin()
75
76
def TemplateFile(name):
  """Returns the absolute path of file |name| in this script's directory.

  Uses abspath, like _ScriptDir() and the jinja2 loader.  A bare
  dirname(__file__) can be relative, or even '', so the result would depend
  on the current working directory.
  """
  return os.path.join(os.path.dirname(os.path.abspath(__file__)), name)
79
80
# Wraps comma separated lists as needed.
def Wrap(pre, items, post, width=80, indent='    '):
  """Formats `pre + ', '.join(items) + post`, wrapping it if it is too long.

  Args:
    pre: text preceding the list, e.g. 'void Foo('.
    items: the list elements.
    post: text following the list, e.g. ') {'.
    width: maximum length of the single-line form.
    indent: prefix for each element when the list is wrapped.

  Returns:
    A list of lines.  This is the single-line form if it fits in |width| or
    if there are no items to wrap.  Otherwise it is |pre| followed by one
    indented, comma-terminated line per item, with |post| after the last.
  """
  items = list(items)
  complete = pre + ', '.join(items) + post
  # With no items there is nothing to break onto separate lines.  Returning
  # the joined form keeps |post|, which would otherwise be dropped.
  if len(complete) <= width or not items:
    return [complete]
  lines = [pre]
  lines.extend(indent + item + ',' for item in items[:-1])
  lines.append(indent + items[-1] + post)
  return lines
94
95
def GeneratorWarning():
  """Returns the two-line C++ comment banner that marks generated files."""
  script_name = os.path.basename(__file__)
  banner = ('// WARNING this file was generated by %s\n'
            '// Do not edit by hand.')
  return banner % script_name
99
100
# Untrusted library implementing the public Mojo API.
def GenerateLibMojo(functions, out):
  """Renders libmojo.cc.tmpl with one C++ stub per entry in |functions|.

  Each stub fills a uint32_t params array and hands it to DoMojoCall():
    - slot 0 holds the message ID (f.uid);
    - slot uid + 1 holds a pointer to each parameter;
    - the last slot holds a pointer to the result variable.
  The stub then returns the result variable.  The rendered text is written
  to |out|.

  Raises:
    Exception: if a function's result type has no known default value.
  """
  # Initial value of the result variable, keyed by its C type.
  result_defaults = {
    'MojoResult': 'MOJO_RESULT_INVALID_ARGUMENT',
    'MojoTimeTicks': '0',
  }
  cast_template = 'params[%d] = reinterpret_cast<uint32_t>(%s);'

  template = jinja_env.get_template('libmojo.cc.tmpl')
  code = CodeWriter()

  for f in functions:
    signature = Wrap('%s %s(' % (f.return_type, f.name), f.ParamList(), ') {')
    for line in signature:
      code << line

    # Message ID + one slot per parameter + the result pointer.
    num_params = len(f.params) + 2
    result = f.result_param

    with code.Indent():
      code << 'uint32_t params[%d];' % num_params
      return_type = result.base_type
      if return_type not in result_defaults:
        raise Exception('Unhandled return type: ' + return_type)
      code << '%s %s = %s;' % (return_type, result.name,
                               result_defaults[return_type])

      code << 'params[0] = %d;' % f.uid
      for p in f.params:
        # A parameter passed by value is sent as a pointer to its local copy.
        ptr = '&' + p.name if p.IsPassedByValue() else p.name
        code << cast_template % (p.uid + 1, ptr)
      code << cast_template % (num_params - 1, '&' + result.name)

      code << 'DoMojoCall(params, sizeof(params));'
      code << 'return %s;' % result.name

    code << '}'
    code << ''

  body = code.GetValue()
  out.write(template.render(generator_warning=GeneratorWarning(), body=body))
148
149
class ParamImpl(object):
  """Base class for moving one parameter across the untrusted/trusted boundary.

  Parameters passed into trusted code are handled differently depending on
  their details.  Each subclass encapsulates one such case and generates the
  C++ code that transfers the parameter.
  """

  def __init__(self, param):
    self.param = param

  def DeclareVars(self, code):
    """Emits into |code| whatever variables this parameter needs."""
    raise NotImplementedError()

  def ConvertParam(self):
    """Returns a C++ expression that converts the untrusted representation.

    The trusted form is a scalar value or a trusted pointer into the
    untrusted address space.
    """
    raise NotImplementedError()

  def CallParam(self):
    """Returns the expression passed when invoking the trusted Mojo API."""
    raise NotImplementedError()

  def CopyOut(self, code):
    """Emits code that copies data back into untrusted memory after the call.

    Does nothing by default; output and in/out parameters override this.
    """
    return None

  def IsArray(self):
    """Returns whether conversion must be deferred to the second pass.

    Arrays can only be converted after the scalar parameter holding their
    size has been converted.
    """
    return False
182
183
class ScalarInputImpl(ParamImpl):
  """Input-only scalar, converted into a local '<name>_value'.

  The local is passed to the trusted API by value.
  """

  def DeclareVars(self, code):
    p = self.param
    code << '%s %s_value;' % (p.base_type, p.name)

  def ConvertParam(self):
    p = self.param
    slot = p.uid + 1
    return 'ConvertScalarInput(nap, params[%d], &%s_value)' % (slot, p.name)

  def CallParam(self):
    return self.param.name + '_value'
195
196
class ScalarOutputImpl(ParamImpl):
  """Output-only scalar.

  The trusted API writes into a local '<name>_value'.  The local is then
  stored through '<name>_ptr', the pointer produced by ConvertScalarOutput().
  """

  def DeclareVars(self, code):
    base, name = self.param.base_type, self.param.name
    code << '%s volatile* %s_ptr;' % (base, name)
    code << '%s %s_value;' % (base, name)

  def ConvertParam(self):
    p = self.param
    slot = p.uid + 1
    return 'ConvertScalarOutput(nap, params[%d], &%s_ptr)' % (slot, p.name)

  def CallParam(self):
    return '&' + self.param.name + '_value'

  def CopyOut(self, code):
    code << '*{0}_ptr = {0}_value;'.format(self.param.name)
212
213
class ScalarInOutImpl(ParamImpl):
  """Scalar that is both read and written by the trusted API.

  It may be optional: when the converted pointer is NULL, NULL is forwarded
  to the API and nothing is copied back.
  """

  def DeclareVars(self, code):
    base, name = self.param.base_type, self.param.name
    code << '%s volatile* %s_ptr;' % (base, name)
    code << '%s %s_value;' % (base, name)

  def ConvertParam(self):
    p = self.param
    args = (p.uid + 1, CBool(p.is_optional), p.name, p.name)
    return 'ConvertScalarInOut(nap, params[%d], %s, &%s_value, &%s_ptr)' % args

  def CallParam(self):
    name = self.param.name
    if not self.param.is_optional:
      return '&%s_value' % name
    return '%s_ptr ? &%s_value : NULL' % (name, name)

  def CopyOut(self, code):
    name = self.param.name
    store = '*%s_ptr = %s_value;' % (name, name)
    if not self.param.is_optional:
      code << store
      return
    # Only write back when the caller actually supplied a pointer.
    code << 'if (%s_ptr != NULL) {' % name
    with code.Indent():
      code << store
    code << '}'
240
241
class ArrayImpl(ParamImpl):
  """Array parameter whose length is given by the scalar named |param.size|.

  Conversion is deferred (IsArray() is True) until that size has been
  converted.
  """

  def DeclareVars(self, code):
    p = self.param
    code << '%s %s;' % (p.param_type, p.name)

  def ConvertParam(self):
    p = self.param
    # Elements of a void array are counted as single bytes.
    element_size = '1' if p.base_type == 'void' else 'sizeof(*%s)' % p.name
    return ('ConvertArray(nap, params[%d], %s_value, %s, %s, &%s)' %
            (p.uid + 1, p.size, element_size, CBool(p.is_optional), p.name))

  def CallParam(self):
    return self.param.name

  def IsArray(self):
    return True
262
263
class StructInputImpl(ParamImpl):
  """Struct input parameter, possibly optional, converted via ConvertStruct()."""

  def DeclareVars(self, code):
    p = self.param
    code << '%s %s;' % (p.param_type, p.name)

  def ConvertParam(self):
    p = self.param
    args = (p.uid + 1, CBool(p.is_optional), p.name)
    return 'ConvertStruct(nap, params[%d], %s, &%s)' % args

  def CallParam(self):
    return self.param.name
275
276
def ImplForParam(p):
  """Returns the ParamImpl that knows how to transfer parameter |p|.

  Scalars are split by direction (input, output, or in/out); arrays and
  structs get their own implementations.

  Raises:
    Exception: if |p| fits none of these categories.  This is an explicit
        raise rather than an assert, which `python -O` would strip, making
        the function silently return None.
  """
  if p.IsScalar():
    if not p.is_output:
      return ScalarInputImpl(p)
    if p.is_input:
      return ScalarInOutImpl(p)
    return ScalarOutputImpl(p)
  if p.is_array:
    return ArrayImpl(p)
  if p.is_struct:
    return StructInputImpl(p)
  raise Exception('Unhandled parameter: %r' % (p,))
292
293
def CBool(value):
  """Returns the C++ boolean literal that matches the truthiness of |value|."""
  if value:
    return 'true'
  return 'false'
296
297
def _EmitParamConversion(code, impl):
  """Emits `if (!<conversion>) { return -1; }` for |impl|'s parameter."""
  code << 'if (!%s) {' % impl.ConvertParam()
  with code.Indent():
    code << 'return -1;'
  code << '}'


# A trusted wrapper that validates the arguments passed from untrusted code
# before passing them to the underlying public Mojo API.
def GenerateMojoSyscall(functions, out):
  """Renders mojo_syscall.cc.tmpl with one `case` per entry in |functions|.

  Each case does the following:
    - checks the parameter count;
    - declares temporaries;
    - converts every parameter, non-arrays first and arrays second, under a
      ScopedCopyLock;
    - calls the trusted function, storing its return value in the result
      parameter's '<name>_value' variable;
    - copies outputs back under a second ScopedCopyLock.
  The rendered text is written to |out|.
  """
  template = jinja_env.get_template('mojo_syscall.cc.tmpl')

  code = CodeWriter()
  # One extra level for every case label.  Presumably the template nests the
  # body inside a switch statement (TODO confirm against the template).
  code.PushMargin()

  for f in functions:
    impls = [ImplForParam(p) for p in f.params]
    # The result parameter is declared, converted and copied out like the
    # other parameters, but it is not passed as an argument.
    impls.append(ImplForParam(f.result_param))

    code << 'case %d:' % f.uid

    code.PushMargin()

    code << '{'

    with code.Indent():
      # Message ID + one slot per parameter + the result pointer.
      num_params = len(f.params) + 2
      code << 'if (num_params != %d) {' % num_params
      with code.Indent():
        code << 'return -1;'
      code << '}'

      # Declare temporaries.
      for impl in impls:
        impl.DeclareVars(code)

      code << '{'
      with code.Indent():
        code << 'ScopedCopyLock copy_lock(nap);'
        # Convert and validate pointers in two passes: an array cannot be
        # validated until the parameter holding its size has been converted.
        for impl in impls:
          if not impl.IsArray():
            _EmitParamConversion(code, impl)
        for impl in impls:
          if impl.IsArray():
            _EmitParamConversion(code, impl)
      code << '}'
      code << ''

      # Call.  Assign to the variable the result impl actually declared
      # ('<name>_value'), rather than hard-coding 'result_value'.
      call_params = [impl.CallParam() for impl in impls[:-1]]
      code << '%s_value = %s(%s);' % (f.result_param.name, f.name,
                                      ', '.join(call_params))
      code << ''

      # Write outputs.
      code << '{'
      with code.Indent():
        code << 'ScopedCopyLock copy_lock(nap);'
        for impl in impls:
          impl.CopyOut(code)
      code << '}'
      code << ''

      code << 'return 0;'
    code << '}'

    code.PopMargin()

  body = code.GetValue()
  text = template.render(
    generator_warning=GeneratorWarning(),
    body=body)
  out.write(text)
372
373
def OutFile(dir_path, name):
  """Opens |dir_path|/|name| for writing, creating |dir_path| if needed.

  Missing parent directories are created too.  The caller owns the returned
  file object and must close it.
  """
  # Try makedirs first and tolerate an existing directory.  This avoids the
  # race between an exists() check and makedirs() when another process
  # creates the directory in between (e.g. in a parallel build).
  try:
    os.makedirs(dir_path)
  except OSError:
    if not os.path.isdir(dir_path):
      raise
  return open(os.path.join(dir_path, name), 'w')
378
379
def main(args):
  """Command-line entry point.

  Parses |args| (which must include -d DIR and no positional arguments),
  then writes libmojo.cc and mojo_syscall.cc into DIR.
  """
  usage = 'usage: %prog [options]'
  parser = optparse.OptionParser(usage=usage)
  parser.add_option(
      '-d',
      dest='out_dir',
      metavar='DIR',
      help='output generated code into directory DIR')
  options, positional = parser.parse_args(args=args)
  if not options.out_dir:
    parser.error('-d is required')
  if positional:
    parser.error('unexpected positional arguments: %s' % ' '.join(positional))

  mojo = interface.MakeInterface()

  # Close, and therefore flush, each output deterministically rather than
  # relying on the file object being garbage-collected.
  with OutFile(options.out_dir, 'libmojo.cc') as out:
    GenerateLibMojo(mojo.functions, out)

  with OutFile(options.out_dir, 'mojo_syscall.cc') as out:
    GenerateMojoSyscall(mojo.functions, out)
401
402
# When run as a script, pass the command-line arguments (without the program
# name) to main().
if __name__ == '__main__':
  main(sys.argv[1:])
405