policy_testserver.py revision d0247b1b59f9c528cb6df88b4f2b9afaf80d181e
1# Copyright (c) 2012 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""A bare-bones test server for testing cloud policy support.
6
7This implements a simple cloud policy test server that can be used to test
8chrome's device management service client. The policy information is read from
9the file named device_management in the server's data directory. It contains
10enforced and recommended policies for the device and user scope, and a list
11of managed users.
12
The format of the file is JSON. The root dictionary contains a list under the
key "managed_users". It contains auth tokens for which the server will claim
that the user is managed. The token string "*" indicates that all users are
claimed to be managed. Most other keys in the root dictionary identify request
scopes. The user-request scope is described by a dictionary that holds two
sub-dictionaries: "mandatory" and "recommended". Both of these hold the policy
definitions as key/value stores; their format is identical to what the Linux
implementation reads from /etc.
The device scope holds the policy definitions directly as key/value stores in
the protobuf format.
The root dictionary may also contain these optional, non-scope keys:
"current_key_index" (index of the key used to sign policy, default 0),
"robot_api_auth_code", "service_account_identity", "policy_user",
"invalidation_source" and "invalidation_name".
23
24Example:
25
26{
27  "google/chromeos/device" : {
28    "guest_mode_enabled" : false
29  },
30  "google/chromeos/user" : {
31    "mandatory" : {
32      "HomepageLocation" : "http://www.chromium.org",
33      "IncognitoEnabled" : false
34    },
35     "recommended" : {
36      "JavascriptEnabled": false
37    }
38  },
39  "google/chromeos/publicaccount/user@example.com" : {
40    "mandatory" : {
41      "HomepageLocation" : "http://www.chromium.org"
42    },
43     "recommended" : {
44    }
45  },
46  "managed_users" : [
47    "secret123456"
48  ],
49  "current_key_index": 0,
50  "robot_api_auth_code": "fake_auth_code",
51  "invalidation_source": 1025,
52  "invalidation_name": "UENUPOL"
53}
54
55"""
56
57import BaseHTTPServer
58import cgi
59import google.protobuf.text_format
60import hashlib
61import logging
62import os
63import random
64import re
65import sys
66import time
67import tlslite
68import tlslite.api
69import tlslite.utils
70import tlslite.utils.cryptomath
71
72# The name and availability of the json module varies in python versions.
73try:
74  import simplejson as json
75except ImportError:
76  try:
77    import json
78  except ImportError:
79    json = None
80
81import asn1der
82import testserver_base
83
84import device_management_backend_pb2 as dm
85import cloud_policy_pb2 as cp
86import chrome_extension_policy_pb2 as ep
87
88# Device policy is only available on Chrome OS builds.
89try:
90  import chrome_device_policy_pb2 as dp
91except ImportError:
92  dp = None
93
# ASN.1 object identifier for PKCS#1/RSA: the DER content octets of OID
# 1.2.840.113549.1.1.1 (rsaEncryption).
PKCS1_RSA_OID = '\x2a\x86\x48\x86\xf7\x0d\x01\x01\x01'

# SHA256 sum of "0". Hash an explicit bytes literal so the digest does not
# depend on str and bytes being the same type.
SHA256_0 = hashlib.sha256(b'0').digest()

# List of bad machine identifiers that trigger the |valid_serial_number_missing|
# flag to be set in the policy fetch response.
BAD_MACHINE_IDS = ['123490EN400015']

# List of machines that trigger the server to send a kiosk (RETAIL) enrollment
# response for the register request.
KIOSK_MACHINE_IDS = ['KIOSK']
107
108
class PolicyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Decodes and handles device management requests from clients.

  The handler implements all the request parsing and protobuf message decoding
  and encoding. It calls back into the server to lookup, register, and
  unregister clients.
  """

  def __init__(self, request, client_address, server):
    """Initialize the handler.

    BaseHTTPRequestHandler processes the whole request from within this
    constructor, so per-request state has to be created lazily (see
    GetUniqueParam()).

    Args:
      request: The request received from the client (for TCP-based servers,
          the connection socket).
      client_address: The client address.
      server: The PolicyTestServer object to use for (un)registering clients.
    """
    BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, request,
                                                   client_address, server)

  def GetUniqueParam(self, name):
    """Extracts a unique query parameter from the request.

    Args:
      name: Names the parameter to fetch.
    Returns:
      The parameter value or None if the parameter doesn't exist or is not
      unique.
    """
    if not hasattr(self, '_params'):
      # Parse the query string once and cache it. A path without '?' has no
      # query string, so an empty string is parsed rather than the path.
      self._params = cgi.parse_qs(self.path.partition('?')[2])

    param_list = self._params.get(name, [])
    if len(param_list) == 1:
      return param_list[0]
    return None

  def do_GET(self):
    """Handles GET requests.

    Currently this is only used to serve external policy data."""
    path = self.path.partition('?')[0]
    if path == '/externalpolicydata':
      http_response, raw_reply = self.HandleExternalPolicyDataRequest()
    else:
      http_response = 404
      raw_reply = 'Invalid path'
    self.send_response(http_response)
    self.end_headers()
    self.wfile.write(raw_reply)

  def do_POST(self):
    """Handles POST requests, i.e. device management protocol requests."""
    http_response, raw_reply = self.HandleRequest()
    self.send_response(http_response)
    if http_response == 200:
      self.send_header('Content-Type', 'application/x-protobuffer')
    self.end_headers()
    self.wfile.write(raw_reply)

  def HandleExternalPolicyDataRequest(self):
    """Handles a request to download policy data for a component.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    policy_key = self.GetUniqueParam('key')
    if not policy_key:
      return (400, 'Missing key parameter')
    data = self.server.ReadPolicyDataFromDataDir(policy_key)
    if data is None:
      return (404, 'Policy not found for ' + policy_key)
    return (200, data)

  def HandleRequest(self):
    """Handles a device management request.

    Parses the DeviceManagementRequest protobuf sent in the POST body,
    validates the query parameters and dispatches on the 'request' query
    parameter.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    try:
      length = int(self.headers.getheader('content-length'))
    except (TypeError, ValueError):
      # The Content-Length header is missing (None) or not a number.
      return (400, 'Missing or invalid content-length')
    rmsg = dm.DeviceManagementRequest()
    rmsg.ParseFromString(self.rfile.read(length))

    logging.debug('gaia auth token -> ' +
                  self.headers.getheader('Authorization', ''))
    logging.debug('oauth token -> ' + str(self.GetUniqueParam('oauth_token')))
    logging.debug('deviceid -> ' + str(self.GetUniqueParam('deviceid')))
    self.DumpMessage('Request', rmsg)

    request_type = self.GetUniqueParam('request')
    # A missing (or repeated) deviceid/agent parameter is treated as empty for
    # the length checks instead of crashing on len(None). A missing device id
    # is reported by the individual request handlers.
    device_id = self.GetUniqueParam('deviceid') or ''
    agent = self.GetUniqueParam('agent') or ''
    # Check server side requirements, as defined in
    # device_management_backend.proto.
    if (self.GetUniqueParam('devicetype') != '2' or
        self.GetUniqueParam('apptype') != 'Chrome' or
        (request_type != 'ping' and len(device_id) >= 64) or
        len(agent) >= 64):
      return (400, 'Invalid request parameter')
    if request_type == 'register':
      return self.ProcessRegister(rmsg.register_request)
    elif request_type == 'api_authorization':
      return self.ProcessApiAuthorization(rmsg.service_api_access_request)
    elif request_type == 'unregister':
      return self.ProcessUnregister(rmsg.unregister_request)
    elif request_type == 'policy' or request_type == 'ping':
      return self.ProcessPolicy(rmsg.policy_request, request_type)
    elif request_type == 'enterprise_check':
      return self.ProcessAutoEnrollment(rmsg.auto_enrollment_request)
    else:
      return (400, 'Invalid request parameter')

  def CreatePolicyForExternalPolicyData(self, policy_key):
    """Returns an ExternalPolicyData protobuf for policy_key.

    If there is policy data for policy_key then the download url will be
    set so that it points to that data, and the appropriate hash is also set.
    Otherwise, the protobuf will be empty.

    Args:
      policy_key: the policy type and settings entity id, joined by '/'.

    Returns:
      A serialized ExternalPolicyData.
    """
    settings = ep.ExternalPolicyData()
    data = self.server.ReadPolicyDataFromDataDir(policy_key)
    if data:
      settings.download_url = ('http://%s:%s/externalpolicydata?key=%s' %
                               (self.server.server_name,
                                self.server.server_port,
                                policy_key))
      settings.secure_hash = hashlib.sha1(data).digest()
    return settings.SerializeToString()

  def CheckGoogleLogin(self):
    """Extracts the auth token from the request and returns it. The token may
    either be a GoogleLogin token from an Authorization header, or an OAuth V2
    token from the oauth_token query parameter. Returns None if no token is
    present.
    """
    oauth_token = self.GetUniqueParam('oauth_token')
    if oauth_token:
      return oauth_token

    match = re.match('GoogleLogin auth=(\\w+)',
                     self.headers.getheader('Authorization', ''))
    if match:
      return match.group(1)

    return None

  def ProcessRegister(self, msg):
    """Handles a register request.

    Checks the query for authorization and device identifier, registers the
    device with the server and constructs a response.

    Args:
      msg: The DeviceRegisterRequest message received from the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    # Check the auth token and device ID.
    auth = self.CheckGoogleLogin()
    if not auth:
      return (403, 'No authorization')

    policy = self.server.GetPolicies()
    # GetPolicies() returns {} when the policy file can't be loaded; a missing
    # "managed_users" list means nobody is managed rather than a KeyError.
    managed_users = policy.get('managed_users', [])
    if '*' not in managed_users and auth not in managed_users:
      return (403, 'Unmanaged')

    device_id = self.GetUniqueParam('deviceid')
    if not device_id:
      return (400, 'Missing device identifier')

    token_info = self.server.RegisterDevice(device_id,
                                            msg.machine_id,
                                            msg.type)

    # Send back the reply.
    response = dm.DeviceManagementResponse()
    response.register_response.device_management_token = (
        token_info['device_token'])
    response.register_response.machine_name = token_info['machine_name']
    response.register_response.enrollment_type = token_info['enrollment_mode']

    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def ProcessApiAuthorization(self, msg):
    """Handles an API authorization request.

    Args:
      msg: The DeviceServiceApiAccessRequest message received from the client.
          Its contents are not inspected.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    policy = self.server.GetPolicies()

    # Return the auth code from the config file if it's defined,
    # else return a descriptive default value.
    response = dm.DeviceManagementResponse()
    response.service_api_access_response.auth_code = policy.get(
        'robot_api_auth_code', 'policy_testserver.py-auth_code')
    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def ProcessUnregister(self, msg):
    """Handles an unregister request.

    Checks for authorization, unregisters the device and constructs the
    response.

    Args:
      msg: The DeviceUnregisterRequest message received from the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    # Check the management token.
    token, response = self.CheckToken()
    if not token:
      return response

    # Unregister the device.
    self.server.UnregisterDevice(token['device_token'])

    # Prepare and send the response.
    response = dm.DeviceManagementResponse()
    response.unregister_response.CopyFrom(dm.DeviceUnregisterResponse())

    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def ProcessPolicy(self, msg, request_type):
    """Handles a policy request.

    Checks for authorization, encodes the policy into protobuf representation
    and constructs the response.

    Args:
      msg: The DevicePolicyRequest message received from the client.
      request_type: The 'request' query parameter, 'policy' or 'ping'. Only
          'policy' fetches policy; for 'ping' every supported policy type gets
          a 400 error entry in the response.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    token_info, error = self.CheckToken()
    if not token_info:
      return error

    response = dm.DeviceManagementResponse()
    for request in msg.request:
      fetch_response = response.policy_response.response.add()
      if (request.policy_type in
             ('google/chrome/user',
              'google/chromeos/user',
              'google/chromeos/device',
              'google/chromeos/publicaccount',
              'google/chrome/extension')):
        if request_type != 'policy':
          fetch_response.error_code = 400
          fetch_response.error_message = 'Invalid request type'
        else:
          self.ProcessCloudPolicy(request, token_info, fetch_response)
      else:
        fetch_response.error_code = 400
        fetch_response.error_message = 'Invalid policy_type'

    return (200, response.SerializeToString())

  def ProcessAutoEnrollment(self, msg):
    """Handles an auto-enrollment check request.

    The reply depends on the value of the modulus:
      1: replies with no new modulus and the sha256 hash of "0"
      2: replies with a new modulus, 4.
      4: replies with a new modulus, 2.
      8: fails with error 400.
      16: replies with a new modulus, 16.
      32: replies with a new modulus, 1.
      anything else: replies with no new modulus and an empty list of hashes

    These allow the client to pick the testing scenario it wants to simulate.

    Args:
      msg: The DeviceAutoEnrollmentRequest message received from the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    auto_enrollment_response = dm.DeviceAutoEnrollmentResponse()

    if msg.modulus == 1:
      auto_enrollment_response.hash.append(SHA256_0)
    elif msg.modulus == 2:
      auto_enrollment_response.expected_modulus = 4
    elif msg.modulus == 4:
      auto_enrollment_response.expected_modulus = 2
    elif msg.modulus == 8:
      return (400, 'Server error')
    elif msg.modulus == 16:
      auto_enrollment_response.expected_modulus = 16
    elif msg.modulus == 32:
      auto_enrollment_response.expected_modulus = 1

    response = dm.DeviceManagementResponse()
    response.auto_enrollment_response.CopyFrom(auto_enrollment_response)
    return (200, response.SerializeToString())

  def SetProtobufMessageField(self, group_message, field, field_value):
    '''Sets a field in a protobuf message.

    Repeated fields and StringList messages are filled entry by entry.
    Repeated message fields are filled recursively from a list of dicts.

    Args:
      group_message: The protobuf message.
      field: The field of the message to set, it should be a member of
          group_message.DESCRIPTOR.fields.
      field_value: The value to set.

    Raises:
      Exception: if the field type is not supported.
    '''
    if field.label == field.LABEL_REPEATED:
      assert type(field_value) == list
      entries = getattr(group_message, field.name)
      if field.message_type is None:
        for list_item in field_value:
          entries.append(list_item)
      else:
        # This field is itself a protobuf.
        for sub_value in field_value:
          assert type(sub_value) == dict
          # Add a new sub-protobuf per list entry.
          sub_message = entries.add()
          # Now iterate over its fields and recursively add them.
          for sub_field in sub_message.DESCRIPTOR.fields:
            if sub_field.name in sub_value:
              value = sub_value[sub_field.name]
              self.SetProtobufMessageField(sub_message, sub_field, value)
      return
    elif field.type == field.TYPE_BOOL:
      assert type(field_value) == bool
    elif field.type == field.TYPE_STRING:
      assert type(field_value) == str or type(field_value) == unicode
    elif field.type == field.TYPE_INT64:
      assert type(field_value) == int
    elif (field.type == field.TYPE_MESSAGE and
          field.message_type.name == 'StringList'):
      assert type(field_value) == list
      entries = getattr(group_message, field.name).entries
      for list_item in field_value:
        entries.append(list_item)
      return
    else:
      raise Exception('Unknown field type %s' % field.type)
    setattr(group_message, field.name, field_value)

  def GatherDevicePolicySettings(self, settings, policies):
    '''Copies all the policies from a dictionary into a protobuf of type
    ChromeDeviceSettingsProto.

    Args:
      settings: The destination ChromeDeviceSettingsProto protobuf.
      policies: The source dictionary containing policies in JSON format.
    '''
    for group in settings.DESCRIPTOR.fields:
      # Create protobuf message for group. The group's message type is looked
      # up by name in the device policy module.
      group_message = getattr(dp, group.message_type.name)()
      # Indicates if at least one field was set in |group_message|.
      got_fields = False
      # Iterate over fields of the message and feed them from the
      # policy config file.
      for field in group_message.DESCRIPTOR.fields:
        if field.name in policies:
          got_fields = True
          self.SetProtobufMessageField(group_message, field,
                                       policies[field.name])
      if got_fields:
        getattr(settings, group.name).CopyFrom(group_message)

  def GatherUserPolicySettings(self, settings, policies):
    '''Copies all the policies from a dictionary into a protobuf of type
    CloudPolicySettings.

    A policy listed under 'mandatory' takes precedence over the same policy
    listed under 'recommended'.

    Args:
      settings: The destination: a CloudPolicySettings protobuf.
      policies: The source: a dictionary containing policies under keys
          'recommended' and 'mandatory'.
    '''
    for field in settings.DESCRIPTOR.fields:
      # |field| is the entry for a specific policy in the top-level
      # CloudPolicySettings proto.

      # Look for this policy's value in the mandatory or recommended dicts.
      if field.name in policies.get('mandatory', {}):
        mode = cp.PolicyOptions.MANDATORY
        value = policies['mandatory'][field.name]
      elif field.name in policies.get('recommended', {}):
        mode = cp.PolicyOptions.RECOMMENDED
        value = policies['recommended'][field.name]
      else:
        continue

      # Create protobuf message for this policy.
      policy_message = getattr(cp, field.message_type.name)()
      policy_message.policy_options.mode = mode
      field_descriptor = policy_message.DESCRIPTOR.fields_by_name['value']
      self.SetProtobufMessageField(policy_message, field_descriptor, value)
      getattr(settings, field.name).CopyFrom(policy_message)

  def ProcessCloudPolicy(self, msg, token_info, response):
    """Handles a cloud policy request. (New protocol for policy requests.)

    Encodes the policy into protobuf representation, signs it and constructs
    the response.

    Args:
      msg: The CloudPolicyRequest message received from the client.
      token_info: the token extracted from the request.
      response: A PolicyFetchResponse message that should be filled with the
                response data.

    Returns:
      None if an error was recorded in |response|; otherwise a tuple of HTTP
      status 200 and the serialized |response|. ProcessPolicy() ignores the
      return value and only uses the filled-in |response|.
    """

    if msg.machine_id:
      self.server.UpdateMachineId(token_info['device_token'], msg.machine_id)

    # Only policy types allowed for the DM token (see RegisterDevice) are
    # served. A scope that is missing from the config file yields empty
    # settings.
    policy = self.server.GetPolicies()
    policy_key = msg.policy_type
    if msg.settings_entity_id:
      policy_key += '/' + msg.settings_entity_id
    if msg.policy_type not in token_info['allowed_policy_types']:
      response.error_code = 400
      response.error_message = 'Request not allowed for the token used'
      return

    if msg.policy_type in ('google/chromeos/user',
                           'google/chrome/user',
                           'google/chromeos/publicaccount'):
      settings = cp.CloudPolicySettings()
      payload = self.server.ReadPolicyFromDataDir(policy_key, settings)
      if payload is None:
        self.GatherUserPolicySettings(settings, policy.get(policy_key, {}))
        payload = settings.SerializeToString()
    elif dp is not None and msg.policy_type == 'google/chromeos/device':
      settings = dp.ChromeDeviceSettingsProto()
      payload = self.server.ReadPolicyFromDataDir(policy_key, settings)
      if payload is None:
        self.GatherDevicePolicySettings(settings, policy.get(policy_key, {}))
        payload = settings.SerializeToString()
    elif msg.policy_type == 'google/chrome/extension':
      settings = ep.ExternalPolicyData()
      payload = self.server.ReadPolicyFromDataDir(policy_key, settings)
      if payload is None:
        payload = self.CreatePolicyForExternalPolicyData(policy_key)
    else:
      response.error_code = 400
      response.error_message = 'Invalid policy type'
      return

    # Sign with 'current_key_index', defaulting to key 0.
    signing_key = None
    req_key = None
    current_key_index = policy.get('current_key_index', 0)
    nkeys = len(self.server.keys)
    if (msg.signature_type == dm.PolicyFetchRequest.SHA1_RSA and
        current_key_index in range(nkeys)):
      signing_key = self.server.keys[current_key_index]
      if msg.public_key_version in range(1, nkeys + 1):
        # requested key exists, use for signing and rotate.
        req_key = self.server.keys[msg.public_key_version - 1]['private_key']

    # Fill the policy data protobuf.
    policy_data = dm.PolicyData()
    policy_data.policy_type = msg.policy_type
    policy_data.timestamp = int(time.time() * 1000)
    policy_data.request_token = token_info['device_token']
    policy_data.policy_value = payload
    policy_data.machine_name = token_info['machine_name']
    policy_data.valid_serial_number_missing = (
        token_info['machine_id'] in BAD_MACHINE_IDS)
    policy_data.settings_entity_id = msg.settings_entity_id
    policy_data.service_account_identity = policy.get(
        'service_account_identity',
        'policy_testserver.py-service_account_identity')
    invalidation_source = policy.get('invalidation_source')
    if invalidation_source is not None:
      policy_data.invalidation_source = invalidation_source
    # Since invalidation_name is type bytes in the proto, the Unicode name
    # provided needs to be encoded as ASCII to set the correct byte pattern.
    invalidation_name = policy.get('invalidation_name')
    if invalidation_name is not None:
      policy_data.invalidation_name = invalidation_name.encode('ascii')

    if signing_key:
      policy_data.public_key_version = current_key_index + 1
    if msg.policy_type == 'google/chromeos/publicaccount':
      policy_data.username = msg.settings_entity_id
    else:
      # For regular user/device policy, there is no way for the testserver to
      # know the user name belonging to the GAIA auth token we received (short
      # of actually talking to GAIA). To address this, we read the username from
      # the policy configuration dictionary, or use a default.
      policy_data.username = policy.get('policy_user', 'user@example.com')
    policy_data.device_id = token_info['device_id']
    signed_data = policy_data.SerializeToString()

    response.policy_data = signed_data
    if signing_key:
      response.policy_data_signature = (
          signing_key['private_key'].hashAndSign(signed_data).tostring())
      if msg.public_key_version != current_key_index + 1:
        # The client holds a different key version: send the current public
        # key, signed with the client's key when that one is known.
        response.new_public_key = signing_key['public_key']
        if req_key:
          response.new_public_key_signature = (
              req_key.hashAndSign(response.new_public_key).tostring())

    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def CheckToken(self):
    """Helper for checking whether the client supplied a valid DM token.

    Extracts the token from the request and passes it to the server in order
    to look up the client.

    Returns:
      A pair of token information record and error response. If the first
      element is None, then the second contains an error code to send back to
      the client. Otherwise the first element is the same structure that is
      returned by LookupToken().
    """
    request_device_id = self.GetUniqueParam('deviceid')
    # (\w+) only matches non-empty tokens, so a match always yields a token.
    match = re.match('GoogleDMToken token=(\\w+)',
                     self.headers.getheader('Authorization', ''))
    if not match:
      error = 401
    else:
      token_info = self.server.LookupToken(match.group(1))
      if (not token_info or
          not request_device_id or
          token_info['device_id'] != request_device_id):
        error = 410
      else:
        return (token_info, None)

    logging.debug('Token check failed with error %d', error)

    return (None, (error, 'Server error %d' % error))

  def DumpMessage(self, label, msg):
    """Helper for logging an ASCII dump of a protobuf message.

    Formatting is deferred to the logging module, so the message is only
    converted to text when debug logging is enabled.
    """
    logging.debug('%s\n%s', label, msg)
673
674
675class PolicyTestServer(testserver_base.ClientRestrictingServerMixIn,
676                       testserver_base.BrokenPipeHandlerMixIn,
677                       testserver_base.StoppableHTTPServer):
678  """Handles requests and keeps global service state."""
679
680  def __init__(self, server_address, data_dir, policy_path, client_state_file,
681               private_key_paths):
682    """Initializes the server.
683
684    Args:
685      server_address: Server host and port.
686      policy_path: Names the file to read JSON-formatted policy from.
687      private_key_paths: List of paths to read private keys from.
688    """
689    testserver_base.StoppableHTTPServer.__init__(self, server_address,
690                                                 PolicyRequestHandler)
691    self._registered_tokens = {}
692    self.data_dir = data_dir
693    self.policy_path = policy_path
694    self.client_state_file = client_state_file
695
696    self.keys = []
697    if private_key_paths:
698      # Load specified keys from the filesystem.
699      for key_path in private_key_paths:
700        try:
701          key_str = open(key_path).read()
702        except IOError:
703          print 'Failed to load private key from %s' % key_path
704          continue
705
706        try:
707          key = tlslite.api.parsePEMKey(key_str, private=True)
708        except SyntaxError:
709          key = tlslite.utils.Python_RSAKey.Python_RSAKey._parsePKCS8(
710              tlslite.utils.cryptomath.stringToBytes(key_str))
711
712        assert key is not None
713        self.keys.append({ 'private_key' : key })
714    else:
715      # Generate 2 private keys if none were passed from the command line.
716      for i in range(2):
717        key = tlslite.api.generateRSAKey(512)
718        assert key is not None
719        self.keys.append({ 'private_key' : key })
720
721    # Derive the public keys from the private keys.
722    for entry in self.keys:
723      key = entry['private_key']
724
725      algorithm = asn1der.Sequence(
726          [ asn1der.Data(asn1der.OBJECT_IDENTIFIER, PKCS1_RSA_OID),
727            asn1der.Data(asn1der.NULL, '') ])
728      rsa_pubkey = asn1der.Sequence([ asn1der.Integer(key.n),
729                                      asn1der.Integer(key.e) ])
730      pubkey = asn1der.Sequence([ algorithm, asn1der.Bitstring(rsa_pubkey) ])
731      entry['public_key'] = pubkey
732
733    # Load client state.
734    if self.client_state_file is not None:
735      try:
736        file_contents = open(self.client_state_file).read()
737        self._registered_tokens = json.loads(file_contents)
738      except IOError:
739        pass
740
741  def GetPolicies(self):
742    """Returns the policies to be used, reloaded form the backend file every
743       time this is called.
744    """
745    policy = {}
746    if json is None:
747      print 'No JSON module, cannot parse policy information'
748    else :
749      try:
750        policy = json.loads(open(self.policy_path).read())
751      except IOError:
752        print 'Failed to load policy from %s' % self.policy_path
753    return policy
754
755  def RegisterDevice(self, device_id, machine_id, type):
756    """Registers a device or user and generates a DM token for it.
757
758    Args:
759      device_id: The device identifier provided by the client.
760
761    Returns:
762      The newly generated device token for the device.
763    """
764    dmtoken_chars = []
765    while len(dmtoken_chars) < 32:
766      dmtoken_chars.append(random.choice('0123456789abcdef'))
767    dmtoken = ''.join(dmtoken_chars)
768    allowed_policy_types = {
769      dm.DeviceRegisterRequest.BROWSER: ['google/chrome/user'],
770      dm.DeviceRegisterRequest.USER: [
771          'google/chromeos/user',
772          'google/chrome/extension'
773      ],
774      dm.DeviceRegisterRequest.DEVICE: [
775          'google/chromeos/device',
776          'google/chromeos/publicaccount'
777      ],
778      dm.DeviceRegisterRequest.TT: ['google/chromeos/user',
779                                    'google/chrome/user'],
780    }
781    if machine_id in KIOSK_MACHINE_IDS:
782      enrollment_mode = dm.DeviceRegisterResponse.RETAIL
783    else:
784      enrollment_mode = dm.DeviceRegisterResponse.ENTERPRISE
785    self._registered_tokens[dmtoken] = {
786      'device_id': device_id,
787      'device_token': dmtoken,
788      'allowed_policy_types': allowed_policy_types[type],
789      'machine_name': 'chromeos-' + machine_id,
790      'machine_id': machine_id,
791      'enrollment_mode': enrollment_mode,
792    }
793    self.WriteClientState()
794    return self._registered_tokens[dmtoken]
795
796  def UpdateMachineId(self, dmtoken, machine_id):
797    """Updates the machine identifier for a registered device.
798
799    Args:
800      dmtoken: The device management token provided by the client.
801      machine_id: Updated hardware identifier value.
802    """
803    if dmtoken in self._registered_tokens:
804      self._registered_tokens[dmtoken]['machine_id'] = machine_id
805      self.WriteClientState()
806
807  def LookupToken(self, dmtoken):
808    """Looks up a device or a user by DM token.
809
810    Args:
811      dmtoken: The device management token provided by the client.
812
813    Returns:
814      A dictionary with information about a device or user that is registered by
815      dmtoken, or None if the token is not found.
816    """
817    return self._registered_tokens.get(dmtoken, None)
818
819  def UnregisterDevice(self, dmtoken):
820    """Unregisters a device identified by the given DM token.
821
822    Args:
823      dmtoken: The device management token provided by the client.
824    """
825    if dmtoken in self._registered_tokens.keys():
826      del self._registered_tokens[dmtoken]
827      self.WriteClientState()
828
829  def WriteClientState(self):
830    """Writes the client state back to the file."""
831    if self.client_state_file is not None:
832      json_data = json.dumps(self._registered_tokens)
833      open(self.client_state_file, 'w').write(json_data)
834
835  def GetBaseFilename(self, policy_selector):
836    """Returns the base filename for the given policy_selector.
837
838    Args:
839      policy_selector: the policy type and settings entity id, joined by '/'.
840
841    Returns:
842      The filename corresponding to the policy_selector, without a file
843      extension.
844    """
845    sanitized_policy_selector = re.sub('[^A-Za-z0-9.@-]', '_', policy_selector)
846    return os.path.join(self.data_dir or '',
847                        'policy_%s' % sanitized_policy_selector)
848
849  def ReadPolicyFromDataDir(self, policy_selector, proto_message):
850    """Tries to read policy payload from a file in the data directory.
851
852    First checks for a binary rendition of the policy protobuf in
853    <data_dir>/policy_<sanitized_policy_selector>.bin. If that exists, returns
854    it. If that file doesn't exist, tries
855    <data_dir>/policy_<sanitized_policy_selector>.txt and decodes that as a
856    protobuf using proto_message. If that fails as well, returns None.
857
858    Args:
859      policy_selector: Selects which policy to read.
860      proto_message: Optional protobuf message object used for decoding the
861          proto text format.
862
863    Returns:
864      The binary payload message, or None if not found.
865    """
866    base_filename = self.GetBaseFilename(policy_selector)
867
868    # Try the binary payload file first.
869    try:
870      return open(base_filename + '.bin').read()
871    except IOError:
872      pass
873
874    # If that fails, try the text version instead.
875    if proto_message is None:
876      return None
877
878    try:
879      text = open(base_filename + '.txt').read()
880      google.protobuf.text_format.Merge(text, proto_message)
881      return proto_message.SerializeToString()
882    except IOError:
883      return None
884    except google.protobuf.text_format.ParseError:
885      return None
886
887  def ReadPolicyDataFromDataDir(self, policy_selector):
888    """Returns the external policy data for |policy_selector| if found.
889
890    Args:
891      policy_selector: Selects which policy to read.
892
893    Returns:
894      The data for the corresponding policy type and entity id, if found.
895    """
896    base_filename = self.GetBaseFilename(policy_selector)
897    try:
898      return open(base_filename + '.data').read()
899    except IOError:
900      return None
901
902
class PolicyServerRunner(testserver_base.TestServerRunner):
  """Command-line runner that configures and starts a PolicyTestServer."""

  def create_server(self, server_data):
    """Creates the policy test server from the parsed command-line options.

    Args:
      server_data: Dictionary into which the server's bound port is stored
          under the key 'port'.

    Returns:
      The newly created PolicyTestServer instance.
    """
    data_dir = self.options.data_dir or ''
    # Default to <data_dir>/device_management unless --config-file is given.
    config_file = (self.options.config_file or
                   os.path.join(data_dir, 'device_management'))
    server = PolicyTestServer((self.options.host, self.options.port),
                              data_dir, config_file,
                              self.options.client_state_file,
                              self.options.policy_keys)
    server_data['port'] = server.server_port
    return server

  def add_options(self):
    """Registers the policy-server-specific command-line options."""
    testserver_base.TestServerRunner.add_options(self)
    self.option_parser.add_option('--client-state', dest='client_state_file',
                                  help='File that client state should be '
                                  'persisted to. This allows the server to be '
                                  'seeded by a list of pre-registered clients '
                                  'and restarts without abandoning registered '
                                  'clients.')
    self.option_parser.add_option('--policy-key', action='append',
                                  dest='policy_keys',
                                  help='Specify a path to a PEM-encoded '
                                  'private key to use for policy signing. May '
                                  'be specified multiple times in order to '
                                  'load multiple keys into the server. If the '
                                  'server has multiple keys, it will rotate '
                                  'through them at each request in a '
                                  'round-robin fashion. The server will '
                                  'generate a random key if none is specified '
                                  'on the command line.')
    self.option_parser.add_option('--log-level', dest='log_level',
                                  default='WARN',
                                  help='Log level threshold to use.')
    self.option_parser.add_option('--config-file', dest='config_file',
                                  help='Specify a configuration file to use '
                                  'instead of the default '
                                  '<data_dir>/device_management')

  def run_server(self):
    """Configures root logging from the options, then runs the server.

    Raises:
      ValueError: if --log-level does not name a logging level.
    """
    logger = logging.getLogger()
    level = getattr(logging, str(self.options.log_level).upper(), None)
    # A bare getattr would raise an opaque AttributeError on typos and would
    # also accept non-level module attributes such as BASIC_FORMAT.
    if not isinstance(level, int):
      raise ValueError('Invalid --log-level value: %r' %
                       self.options.log_level)
    logger.setLevel(level)
    if self.options.log_to_console:
      logger.addHandler(logging.StreamHandler())
    if self.options.log_file:
      logger.addHandler(logging.FileHandler(self.options.log_file))

    testserver_base.TestServerRunner.run_server(self)
955
956
if __name__ == '__main__':
  # Run the policy test server and propagate its exit status to the shell.
  runner = PolicyServerRunner()
  sys.exit(runner.main())
959