policy_testserver.py revision bb1529ce867d8845a77ec7cdf3e3003ef1771a40
1# Copyright (c) 2012 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""A bare-bones test server for testing cloud policy support.
6
7This implements a simple cloud policy test server that can be used to test
8chrome's device management service client. The policy information is read from
9the file named device_management in the server's data directory. It contains
10enforced and recommended policies for the device and user scope, and a list
11of managed users.
12
13The format of the file is JSON. The root dictionary contains a list under the
14key "managed_users". It contains auth tokens for which the server will claim
15that the user is managed. The token string "*" indicates that all users are
16claimed to be managed. Other keys in the root dictionary identify request
17scopes. The user-request scope is described by a dictionary that holds two
18sub-dictionaries: "mandatory" and "recommended". Both these hold the policy
19definitions as key/value stores, their format is identical to what the Linux
20implementation reads from /etc.
The device scope holds the policy definitions directly as key/value pairs, keyed
by the field names of the device settings protobuf messages.
23
24Example:
25
26{
27  "google/chromeos/device" : {
28    "guest_mode_enabled" : false
29  },
30  "google/chromeos/user" : {
31    "mandatory" : {
32      "HomepageLocation" : "http://www.chromium.org",
33      "IncognitoEnabled" : false
34    },
35     "recommended" : {
36      "JavascriptEnabled": false
37    }
38  },
39  "google/chromeos/publicaccount/user@example.com" : {
40    "mandatory" : {
41      "HomepageLocation" : "http://www.chromium.org"
42    },
43     "recommended" : {
44    }
45  },
46  "managed_users" : [
47    "secret123456"
48  ],
49  "current_key_index": 0,
50  "robot_api_auth_code": "fake_auth_code"
51}
52
53"""
54
55import BaseHTTPServer
56import cgi
57import google.protobuf.text_format
58import hashlib
59import logging
60import os
61import random
62import re
63import sys
64import time
65import tlslite
66import tlslite.api
67import tlslite.utils
68import tlslite.utils.cryptomath
69
70# The name and availability of the json module varies in python versions.
71try:
72  import simplejson as json
73except ImportError:
74  try:
75    import json
76  except ImportError:
77    json = None
78
79import asn1der
80import testserver_base
81
82import device_management_backend_pb2 as dm
83import cloud_policy_pb2 as cp
84import chrome_extension_policy_pb2 as ep
85
86# Device policy is only available on Chrome OS builds.
87try:
88  import chrome_device_policy_pb2 as dp
89except ImportError:
90  dp = None
91
# ASN.1 object identifier for PKCS#1/RSA.
PKCS1_RSA_OID = '\x2a\x86\x48\x86\xf7\x0d\x01\x01\x01'

# SHA256 sum of "0". The b'' literal is a plain str on Python 2.6+ and bytes on
# Python 3, so the digest is identical either way.
SHA256_0 = hashlib.sha256(b'0').digest()

# List of bad machine identifiers that trigger the |valid_serial_number_missing|
# flag to be set in the policy fetch response.
BAD_MACHINE_IDS = [ '123490EN400015' ]

# List of machines that trigger the server to send kiosk enrollment response
# for the register request.
KIOSK_MACHINE_IDS = [ 'KIOSK' ]
105
106
class PolicyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Decodes and handles device management requests from clients.

  The handler implements all the request parsing and protobuf message decoding
  and encoding. It calls back into the server to lookup, register, and
  unregister clients.
  """

  def __init__(self, request, client_address, server):
    """Initialize the handler.

    Args:
      request: The request data received from the client as a string.
      client_address: The client address.
      server: The TestServer object to use for (un)registering clients.
    """
    BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, request,
                                                   client_address, server)

  def GetUniqueParam(self, name):
    """Extracts a unique query parameter from the request.

    The query string is parsed on first use and cached in self._params. A path
    without a '?' has no query parameters.

    Args:
      name: Names the parameter to fetch.
    Returns:
      The parameter value or None if the parameter doesn't exist or is not
      unique.
    """
    if not hasattr(self, '_params'):
      # partition() yields an empty query string when there is no '?'; slicing
      # from find('?') + 1 would instead parse the whole path as a query.
      self._params = cgi.parse_qs(self.path.partition('?')[2])

    param_list = self._params.get(name, [])
    if len(param_list) == 1:
      return param_list[0]
    return None

  def do_GET(self):
    """Handles GET requests.

    Currently this is only used to serve external policy data."""
    path = self.path.partition('?')[0]
    if path == '/externalpolicydata':
      http_response, raw_reply = self.HandleExternalPolicyDataRequest()
    else:
      http_response = 404
      raw_reply = 'Invalid path'
    self.send_response(http_response)
    self.end_headers()
    self.wfile.write(raw_reply)

  def do_POST(self):
    """Handles POST requests, which carry device management protobufs."""
    http_response, raw_reply = self.HandleRequest()
    self.send_response(http_response)
    if http_response == 200:
      self.send_header('Content-Type', 'application/x-protobuffer')
    self.end_headers()
    self.wfile.write(raw_reply)

  def HandleExternalPolicyDataRequest(self):
    """Handles a request to download policy data for a component.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    policy_key = self.GetUniqueParam('key')
    if not policy_key:
      return (400, 'Missing key parameter')
    data = self.server.ReadPolicyDataFromDataDir(policy_key)
    if data is None:
      return (404, 'Policy not found for ' + policy_key)
    return (200, data)

  def HandleRequest(self):
    """Handles a request.

    Parses the data supplied at construction time and returns a pair indicating
    http status code and response data to be sent back to the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    rmsg = dm.DeviceManagementRequest()
    # A missing Content-Length header is treated as an empty body instead of
    # failing with int(None).
    length = int(self.headers.getheader('content-length', 0))
    rmsg.ParseFromString(self.rfile.read(length))

    logging.debug('gaia auth token -> %s',
                  self.headers.getheader('Authorization', ''))
    logging.debug('oauth token -> %s', self.GetUniqueParam('oauth_token'))
    logging.debug('deviceid -> %s', self.GetUniqueParam('deviceid'))
    self.DumpMessage('Request', rmsg)

    request_type = self.GetUniqueParam('request')
    # Missing parameters are treated as empty strings. This avoids raising
    # TypeError from len(None), and it lets the per-request checks below
    # (e.g. ProcessRegister's missing device identifier check, CheckToken's
    # device id check) produce the proper HTTP error.
    device_id = self.GetUniqueParam('deviceid') or ''
    agent = self.GetUniqueParam('agent') or ''
    # Check server side requirements, as defined in
    # device_management_backend.proto.
    if (self.GetUniqueParam('devicetype') != '2' or
        self.GetUniqueParam('apptype') != 'Chrome' or
        (request_type != 'ping' and len(device_id) >= 64) or
        len(agent) >= 64):
      return (400, 'Invalid request parameter')

    if request_type == 'register':
      return self.ProcessRegister(rmsg.register_request)
    elif request_type == 'api_authorization':
      return self.ProcessApiAuthorization(rmsg.service_api_access_request)
    elif request_type == 'unregister':
      return self.ProcessUnregister(rmsg.unregister_request)
    elif request_type in ('policy', 'ping'):
      return self.ProcessPolicy(rmsg.policy_request, request_type)
    elif request_type == 'enterprise_check':
      return self.ProcessAutoEnrollment(rmsg.auto_enrollment_request)
    else:
      return (400, 'Invalid request parameter')

  def CreatePolicyForExternalPolicyData(self, policy_key):
    """Returns an ExternalPolicyData protobuf for policy_key.

    If there is policy data for policy_key then the download url will be
    set so that it points to that data, and the appropriate hash is also set.
    Otherwise, the protobuf will be empty.

    Args:
      policy_key: the policy type and settings entity id, joined by '/'.

    Returns:
      A serialized ExternalPolicyData.
    """
    settings = ep.ExternalPolicyData()
    data = self.server.ReadPolicyDataFromDataDir(policy_key)
    if data:
      settings.download_url = ('http://%s:%s/externalpolicydata?key=%s' %
                               (self.server.server_name,
                                self.server.server_port,
                                policy_key))
      settings.secure_hash = hashlib.sha1(data).digest()
    return settings.SerializeToString()

  def CheckGoogleLogin(self):
    """Extracts the auth token from the request and returns it.

    The token may either be a GoogleLogin token from an Authorization header,
    or an OAuth V2 token from the oauth_token query parameter. The query
    parameter takes precedence.

    Returns:
      The token string, or None if no token is present.
    """
    oauth_token = self.GetUniqueParam('oauth_token')
    if oauth_token:
      return oauth_token

    match = re.match(r'GoogleLogin auth=(\w+)',
                     self.headers.getheader('Authorization', ''))
    if match:
      return match.group(1)

    return None

  def ProcessRegister(self, msg):
    """Handles a register request.

    Checks the query for authorization and device identifier, registers the
    device with the server and constructs a response.

    Args:
      msg: The DeviceRegisterRequest message received from the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    # Check the auth token and device ID.
    auth = self.CheckGoogleLogin()
    if not auth:
      return (403, 'No authorization')

    policy = self.server.GetPolicies()
    managed_users = policy['managed_users']
    # '*' in the managed_users list means every user counts as managed.
    if '*' not in managed_users and auth not in managed_users:
      return (403, 'Unmanaged')

    device_id = self.GetUniqueParam('deviceid')
    if not device_id:
      return (400, 'Missing device identifier')

    token_info = self.server.RegisterDevice(device_id,
                                            msg.machine_id,
                                            msg.type)

    # Send back the reply.
    response = dm.DeviceManagementResponse()
    register_response = response.register_response
    register_response.device_management_token = token_info['device_token']
    register_response.machine_name = token_info['machine_name']
    register_response.enrollment_type = token_info['enrollment_mode']

    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def ProcessApiAuthorization(self, msg):
    """Handles an API authorization request.

    Args:
      msg: The DeviceServiceApiAccessRequest message received from the client.
          Its contents are not inspected.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    policy = self.server.GetPolicies()

    # Return the auth code from the config file if it's defined,
    # else return a descriptive default value.
    response = dm.DeviceManagementResponse()
    response.service_api_access_response.auth_code = policy.get(
        'robot_api_auth_code', 'policy_testserver.py-auth_code')
    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def ProcessUnregister(self, msg):
    """Handles an unregister request.

    Checks for authorization, unregisters the device and constructs the
    response.

    Args:
      msg: The DeviceUnregisterRequest message received from the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    # Check the management token.
    token_info, error = self.CheckToken()
    if not token_info:
      return error

    # Unregister the device.
    self.server.UnregisterDevice(token_info['device_token'])

    # Prepare and send the response.
    response = dm.DeviceManagementResponse()
    response.unregister_response.CopyFrom(dm.DeviceUnregisterResponse())

    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def ProcessPolicy(self, msg, request_type):
    """Handles a policy request.

    Checks for authorization, encodes the policy into protobuf representation
    and constructs the response. Each fetch request in msg gets its own entry
    in the response; per-fetch failures are reported through that entry's
    error_code/error_message while the overall HTTP status stays 200.

    Args:
      msg: The DevicePolicyRequest message received from the client.
      request_type: The 'request' query parameter. Only 'policy' requests get
          policy data; any other value (HandleRequest also routes 'ping' here)
          yields an 'Invalid request type' error for each supported fetch.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    token_info, error = self.CheckToken()
    if not token_info:
      return error

    supported_policy_types = ('google/chrome/user',
                              'google/chromeos/user',
                              'google/chromeos/device',
                              'google/chromeos/publicaccount',
                              'google/chrome/extension')
    response = dm.DeviceManagementResponse()
    for request in msg.request:
      fetch_response = response.policy_response.response.add()
      if request.policy_type not in supported_policy_types:
        fetch_response.error_code = 400
        fetch_response.error_message = 'Invalid policy_type'
      elif request_type != 'policy':
        fetch_response.error_code = 400
        fetch_response.error_message = 'Invalid request type'
      else:
        self.ProcessCloudPolicy(request, token_info, fetch_response)

    return (200, response.SerializeToString())

  def ProcessAutoEnrollment(self, msg):
    """Handles an auto-enrollment check request.

    The reply depends on the value of the modulus:
      1: replies with no new modulus and the sha256 hash of "0"
      2: replies with a new modulus, 4.
      4: replies with a new modulus, 2.
      8: fails with error 400.
      16: replies with a new modulus, 16.
      32: replies with a new modulus, 1.
      anything else: replies with no new modulus and an empty list of hashes

    These allow the client to pick the testing scenario its wants to simulate.

    Args:
      msg: The DeviceAutoEnrollmentRequest message received from the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    auto_enrollment_response = dm.DeviceAutoEnrollmentResponse()

    if msg.modulus == 1:
      auto_enrollment_response.hash.append(SHA256_0)
    elif msg.modulus == 2:
      auto_enrollment_response.expected_modulus = 4
    elif msg.modulus == 4:
      auto_enrollment_response.expected_modulus = 2
    elif msg.modulus == 8:
      return (400, 'Server error')
    elif msg.modulus == 16:
      auto_enrollment_response.expected_modulus = 16
    elif msg.modulus == 32:
      auto_enrollment_response.expected_modulus = 1

    response = dm.DeviceManagementResponse()
    response.auto_enrollment_response.CopyFrom(auto_enrollment_response)
    return (200, response.SerializeToString())

  def SetProtobufMessageField(self, group_message, field, field_value):
    '''Sets a field in a protobuf message.

    Repeated fields take a list. Entries of a repeated message field are dicts
    that are applied recursively, one sub-message per list entry. StringList
    messages take a list of strings. Scalar bool, string and int64 fields are
    type-checked with assert.

    Args:
      group_message: The protobuf message.
      field: The field of the message to set, it should be a member of
          group_message.DESCRIPTOR.fields.
      field_value: The value to set.

    Raises:
      Exception: if the field has a type this helper does not handle.
    '''
    if field.label == field.LABEL_REPEATED:
      assert type(field_value) == list
      entries = getattr(group_message, field.name)
      if field.message_type is None:
        for list_item in field_value:
          entries.append(list_item)
      else:
        # This field is itself a protobuf.
        for sub_value in field_value:
          assert type(sub_value) == dict
          # Add a new sub-protobuf per list entry.
          sub_message = entries.add()
          # Now iterate over its fields and recursively add them.
          for sub_field in sub_message.DESCRIPTOR.fields:
            if sub_field.name in sub_value:
              value = sub_value[sub_field.name]
              self.SetProtobufMessageField(sub_message, sub_field, value)
      return
    elif field.type == field.TYPE_BOOL:
      assert type(field_value) == bool
    elif field.type == field.TYPE_STRING:
      assert type(field_value) == str or type(field_value) == unicode
    elif field.type == field.TYPE_INT64:
      # type() rather than isinstance() so that bools are rejected.
      assert type(field_value) == int
    elif (field.type == field.TYPE_MESSAGE and
          field.message_type.name == 'StringList'):
      assert type(field_value) == list
      entries = getattr(group_message, field.name).entries
      for list_item in field_value:
        entries.append(list_item)
      return
    else:
      raise Exception('Unknown field type %s' % field.type)
    setattr(group_message, field.name, field_value)

  def GatherDevicePolicySettings(self, settings, policies):
    '''Copies all the policies from a dictionary into a protobuf of type
    ChromeDeviceSettingsProto.

    Each top-level field of settings is a group message. A group is only
    filled in when at least one of its fields appears in policies.

    Args:
      settings: The destination ChromeDeviceSettingsProto protobuf.
      policies: The source dictionary containing policies in JSON format.
    '''
    for group in settings.DESCRIPTOR.fields:
      # Create protobuf message for group. getattr() looks up the generated
      # class by name without resorting to eval().
      group_message = getattr(dp, group.message_type.name)()
      # Indicates if at least one field was set in |group_message|.
      got_fields = False
      # Iterate over fields of the message and feed them from the
      # policy config file.
      for field in group_message.DESCRIPTOR.fields:
        if field.name in policies:
          got_fields = True
          self.SetProtobufMessageField(group_message, field,
                                       policies[field.name])
      if got_fields:
        getattr(settings, group.name).CopyFrom(group_message)

  def GatherUserPolicySettings(self, settings, policies):
    '''Copies all the policies from a dictionary into a protobuf of type
    CloudPolicySettings.

    A policy listed under both keys is treated as mandatory.

    Args:
      settings: The destination: a CloudPolicySettings protobuf.
      policies: The source: a dictionary containing policies under keys
          'recommended' and 'mandatory'.
    '''
    mandatory = policies.get('mandatory', {})
    recommended = policies.get('recommended', {})
    for field in settings.DESCRIPTOR.fields:
      # |field| is the entry for a specific policy in the top-level
      # CloudPolicySettings proto.

      # Look for this policy's value in the mandatory or recommended dicts.
      if field.name in mandatory:
        mode = cp.PolicyOptions.MANDATORY
        value = mandatory[field.name]
      elif field.name in recommended:
        mode = cp.PolicyOptions.RECOMMENDED
        value = recommended[field.name]
      else:
        continue

      # Create protobuf message for this policy. getattr() looks up the
      # generated class by name without resorting to eval().
      policy_message = getattr(cp, field.message_type.name)()
      policy_message.policy_options.mode = mode
      field_descriptor = policy_message.DESCRIPTOR.fields_by_name['value']
      self.SetProtobufMessageField(policy_message, field_descriptor, value)
      getattr(settings, field.name).CopyFrom(policy_message)

  def ProcessCloudPolicy(self, msg, token_info, response):
    """Handles a cloud policy request. (New protocol for policy requests.)

    Encodes the policy into protobuf representation, signs it and constructs
    the response.

    Args:
      msg: The CloudPolicyRequest message received from the client.
      token_info: the token information dictionary for the client's DM token,
          as returned by CheckToken().
      response: A PolicyFetchResponse message that should be filled with the
                response data.

    Returns:
      (200, serialized response) on success, or None when an error was
      recorded in response. ProcessPolicy only relies on response itself.
    """
    if msg.machine_id:
      self.server.UpdateMachineId(token_info['device_token'], msg.machine_id)

    # Response is only given if the scope is specified in the config file.
    # Normally 'google/chromeos/device', 'google/chromeos/user' and
    # 'google/chromeos/publicaccount' should be accepted.
    policy = self.server.GetPolicies()
    policy_key = msg.policy_type
    if msg.settings_entity_id:
      policy_key += '/' + msg.settings_entity_id

    if msg.policy_type not in token_info['allowed_policy_types']:
      response.error_code = 400
      response.error_message = 'Request not allowed for the token used'
      return

    # A payload file in the data dir takes precedence over the JSON config.
    if msg.policy_type in ('google/chromeos/user',
                           'google/chrome/user',
                           'google/chromeos/publicaccount'):
      settings = cp.CloudPolicySettings()
      payload = self.server.ReadPolicyFromDataDir(policy_key, settings)
      if payload is None:
        self.GatherUserPolicySettings(settings, policy.get(policy_key, {}))
        payload = settings.SerializeToString()
    elif dp is not None and msg.policy_type == 'google/chromeos/device':
      settings = dp.ChromeDeviceSettingsProto()
      payload = self.server.ReadPolicyFromDataDir(policy_key, settings)
      if payload is None:
        self.GatherDevicePolicySettings(settings, policy.get(policy_key, {}))
        payload = settings.SerializeToString()
    elif msg.policy_type == 'google/chrome/extension':
      settings = ep.ExternalPolicyData()
      payload = self.server.ReadPolicyFromDataDir(policy_key, settings)
      if payload is None:
        payload = self.CreatePolicyForExternalPolicyData(policy_key)
    else:
      # Includes device policy when chrome_device_policy_pb2 is unavailable.
      response.error_code = 400
      response.error_message = 'Invalid policy type'
      return

    # Sign with 'current_key_index', defaulting to key 0.
    signing_key = None
    req_key = None
    current_key_index = policy.get('current_key_index', 0)
    nkeys = len(self.server.keys)
    if (msg.signature_type == dm.PolicyFetchRequest.SHA1_RSA and
        current_key_index in range(nkeys)):
      signing_key = self.server.keys[current_key_index]
      # public_key_version is 1-based.
      if msg.public_key_version in range(1, nkeys + 1):
        # requested key exists, use for signing and rotate.
        req_key = self.server.keys[msg.public_key_version - 1]['private_key']

    # Fill the policy data protobuf.
    policy_data = dm.PolicyData()
    policy_data.policy_type = msg.policy_type
    policy_data.timestamp = int(time.time() * 1000)
    policy_data.request_token = token_info['device_token']
    policy_data.policy_value = payload
    policy_data.machine_name = token_info['machine_name']
    policy_data.valid_serial_number_missing = (
        token_info['machine_id'] in BAD_MACHINE_IDS)
    policy_data.settings_entity_id = msg.settings_entity_id
    policy_data.service_account_identity = policy.get(
        'service_account_identity',
        'policy_testserver.py-service_account_identity')

    if signing_key:
      policy_data.public_key_version = current_key_index + 1
    if msg.policy_type == 'google/chromeos/publicaccount':
      policy_data.username = msg.settings_entity_id
    else:
      # For regular user/device policy, there is no way for the testserver to
      # know the user name belonging to the GAIA auth token we received (short
      # of actually talking to GAIA). To address this, we read the username from
      # the policy configuration dictionary, or use a default.
      policy_data.username = policy.get('policy_user', 'user@example.com')
    policy_data.device_id = token_info['device_id']
    signed_data = policy_data.SerializeToString()

    response.policy_data = signed_data
    if signing_key:
      response.policy_data_signature = (
          signing_key['private_key'].hashAndSign(signed_data).tostring())
      # Send the current public key when the client does not have it yet,
      # signed with the key the client asked for if that key exists.
      if msg.public_key_version != current_key_index + 1:
        response.new_public_key = signing_key['public_key']
        if req_key:
          response.new_public_key_signature = (
              req_key.hashAndSign(response.new_public_key).tostring())

    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def CheckToken(self):
    """Helper for checking whether the client supplied a valid DM token.

    Extracts the token from the request and passed to the server in order to
    look up the client.

    Returns:
      A pair of token information record and error response. If the first
      element is None, then the second contains an error code to send back to
      the client: 401 when no DM token was supplied, 410 when the token is
      unknown or does not match the deviceid query parameter. Otherwise the
      first element is the same structure that is returned by LookupToken().
    """
    dmtoken = None
    request_device_id = self.GetUniqueParam('deviceid')
    match = re.match(r'GoogleDMToken token=(\w+)',
                     self.headers.getheader('Authorization', ''))
    if match:
      dmtoken = match.group(1)
    if not dmtoken:
      error = 401
    else:
      token_info = self.server.LookupToken(dmtoken)
      if (not token_info or
          not request_device_id or
          token_info['device_id'] != request_device_id):
        error = 410
      else:
        return (token_info, None)

    logging.debug('Token check failed with error %d', error)

    return (None, (error, 'Server error %d' % error))

  def DumpMessage(self, label, msg):
    """Helper for logging an ASCII dump of a protobuf message."""
    # Lazy formatting: str(msg) is only computed when debug logging is on.
    logging.debug('%s\n%s', label, msg)
663
664
665class PolicyTestServer(testserver_base.ClientRestrictingServerMixIn,
666                       testserver_base.BrokenPipeHandlerMixIn,
667                       testserver_base.StoppableHTTPServer):
668  """Handles requests and keeps global service state."""
669
670  def __init__(self, server_address, data_dir, policy_path, client_state_file,
671               private_key_paths):
672    """Initializes the server.
673
674    Args:
675      server_address: Server host and port.
676      policy_path: Names the file to read JSON-formatted policy from.
677      private_key_paths: List of paths to read private keys from.
678    """
679    testserver_base.StoppableHTTPServer.__init__(self, server_address,
680                                                 PolicyRequestHandler)
681    self._registered_tokens = {}
682    self.data_dir = data_dir
683    self.policy_path = policy_path
684    self.client_state_file = client_state_file
685
686    self.keys = []
687    if private_key_paths:
688      # Load specified keys from the filesystem.
689      for key_path in private_key_paths:
690        try:
691          key_str = open(key_path).read()
692        except IOError:
693          print 'Failed to load private key from %s' % key_path
694          continue
695
696        try:
697          key = tlslite.api.parsePEMKey(key_str, private=True)
698        except SyntaxError:
699          key = tlslite.utils.Python_RSAKey.Python_RSAKey._parsePKCS8(
700              tlslite.utils.cryptomath.stringToBytes(key_str))
701
702        assert key is not None
703        self.keys.append({ 'private_key' : key })
704    else:
705      # Generate 2 private keys if none were passed from the command line.
706      for i in range(2):
707        key = tlslite.api.generateRSAKey(512)
708        assert key is not None
709        self.keys.append({ 'private_key' : key })
710
711    # Derive the public keys from the private keys.
712    for entry in self.keys:
713      key = entry['private_key']
714
715      algorithm = asn1der.Sequence(
716          [ asn1der.Data(asn1der.OBJECT_IDENTIFIER, PKCS1_RSA_OID),
717            asn1der.Data(asn1der.NULL, '') ])
718      rsa_pubkey = asn1der.Sequence([ asn1der.Integer(key.n),
719                                      asn1der.Integer(key.e) ])
720      pubkey = asn1der.Sequence([ algorithm, asn1der.Bitstring(rsa_pubkey) ])
721      entry['public_key'] = pubkey
722
723    # Load client state.
724    if self.client_state_file is not None:
725      try:
726        file_contents = open(self.client_state_file).read()
727        self._registered_tokens = json.loads(file_contents)
728      except IOError:
729        pass
730
731  def GetPolicies(self):
732    """Returns the policies to be used, reloaded form the backend file every
733       time this is called.
734    """
735    policy = {}
736    if json is None:
737      print 'No JSON module, cannot parse policy information'
738    else :
739      try:
740        policy = json.loads(open(self.policy_path).read())
741      except IOError:
742        print 'Failed to load policy from %s' % self.policy_path
743    return policy
744
745  def RegisterDevice(self, device_id, machine_id, type):
746    """Registers a device or user and generates a DM token for it.
747
748    Args:
749      device_id: The device identifier provided by the client.
750
751    Returns:
752      The newly generated device token for the device.
753    """
754    dmtoken_chars = []
755    while len(dmtoken_chars) < 32:
756      dmtoken_chars.append(random.choice('0123456789abcdef'))
757    dmtoken = ''.join(dmtoken_chars)
758    allowed_policy_types = {
759      dm.DeviceRegisterRequest.BROWSER: ['google/chrome/user'],
760      dm.DeviceRegisterRequest.USER: [
761          'google/chromeos/user',
762          'google/chrome/extension'
763      ],
764      dm.DeviceRegisterRequest.DEVICE: [
765          'google/chromeos/device',
766          'google/chromeos/publicaccount'
767      ],
768      dm.DeviceRegisterRequest.TT: ['google/chromeos/user',
769                                    'google/chrome/user'],
770    }
771    if machine_id in KIOSK_MACHINE_IDS:
772      enrollment_mode = dm.DeviceRegisterResponse.RETAIL
773    else:
774      enrollment_mode = dm.DeviceRegisterResponse.ENTERPRISE
775    self._registered_tokens[dmtoken] = {
776      'device_id': device_id,
777      'device_token': dmtoken,
778      'allowed_policy_types': allowed_policy_types[type],
779      'machine_name': 'chromeos-' + machine_id,
780      'machine_id': machine_id,
781      'enrollment_mode': enrollment_mode,
782    }
783    self.WriteClientState()
784    return self._registered_tokens[dmtoken]
785
786  def UpdateMachineId(self, dmtoken, machine_id):
787    """Updates the machine identifier for a registered device.
788
789    Args:
790      dmtoken: The device management token provided by the client.
791      machine_id: Updated hardware identifier value.
792    """
793    if dmtoken in self._registered_tokens:
794      self._registered_tokens[dmtoken]['machine_id'] = machine_id
795      self.WriteClientState()
796
797  def LookupToken(self, dmtoken):
798    """Looks up a device or a user by DM token.
799
800    Args:
801      dmtoken: The device management token provided by the client.
802
803    Returns:
804      A dictionary with information about a device or user that is registered by
805      dmtoken, or None if the token is not found.
806    """
807    return self._registered_tokens.get(dmtoken, None)
808
809  def UnregisterDevice(self, dmtoken):
810    """Unregisters a device identified by the given DM token.
811
812    Args:
813      dmtoken: The device management token provided by the client.
814    """
815    if dmtoken in self._registered_tokens.keys():
816      del self._registered_tokens[dmtoken]
817      self.WriteClientState()
818
819  def WriteClientState(self):
820    """Writes the client state back to the file."""
821    if self.client_state_file is not None:
822      json_data = json.dumps(self._registered_tokens)
823      open(self.client_state_file, 'w').write(json_data)
824
825  def GetBaseFilename(self, policy_selector):
826    """Returns the base filename for the given policy_selector.
827
828    Args:
829      policy_selector: the policy type and settings entity id, joined by '/'.
830
831    Returns:
832      The filename corresponding to the policy_selector, without a file
833      extension.
834    """
835    sanitized_policy_selector = re.sub('[^A-Za-z0-9.@-]', '_', policy_selector)
836    return os.path.join(self.data_dir or '',
837                        'policy_%s' % sanitized_policy_selector)
838
839  def ReadPolicyFromDataDir(self, policy_selector, proto_message):
840    """Tries to read policy payload from a file in the data directory.
841
842    First checks for a binary rendition of the policy protobuf in
843    <data_dir>/policy_<sanitized_policy_selector>.bin. If that exists, returns
844    it. If that file doesn't exist, tries
845    <data_dir>/policy_<sanitized_policy_selector>.txt and decodes that as a
846    protobuf using proto_message. If that fails as well, returns None.
847
848    Args:
849      policy_selector: Selects which policy to read.
850      proto_message: Optional protobuf message object used for decoding the
851          proto text format.
852
853    Returns:
854      The binary payload message, or None if not found.
855    """
856    base_filename = self.GetBaseFilename(policy_selector)
857
858    # Try the binary payload file first.
859    try:
860      return open(base_filename + '.bin').read()
861    except IOError:
862      pass
863
864    # If that fails, try the text version instead.
865    if proto_message is None:
866      return None
867
868    try:
869      text = open(base_filename + '.txt').read()
870      google.protobuf.text_format.Merge(text, proto_message)
871      return proto_message.SerializeToString()
872    except IOError:
873      return None
874    except google.protobuf.text_format.ParseError:
875      return None
876
877  def ReadPolicyDataFromDataDir(self, policy_selector):
878    """Returns the external policy data for |policy_selector| if found.
879
880    Args:
881      policy_selector: Selects which policy to read.
882
883    Returns:
884      The data for the corresponding policy type and entity id, if found.
885    """
886    base_filename = self.GetBaseFilename(policy_selector)
887    try:
888      return open(base_filename + '.data').read()
889    except IOError:
890      return None
891
892
class PolicyServerRunner(testserver_base.TestServerRunner):
  """Command-line runner that configures and starts a PolicyTestServer."""

  def __init__(self):
    super(PolicyServerRunner, self).__init__()

  def create_server(self, server_data):
    """Creates the policy test server from the parsed command-line options.

    Args:
      server_data: Dictionary that receives the 'port' the server is bound to.

    Returns:
      The newly created PolicyTestServer instance.
    """
    data_dir = self.options.data_dir or ''
    # Fall back to <data_dir>/device_management when --config-file is absent.
    config_file = (self.options.config_file or
                   os.path.join(data_dir, 'device_management'))
    server = PolicyTestServer((self.options.host, self.options.port),
                              data_dir, config_file,
                              self.options.client_state_file,
                              self.options.policy_keys)
    # Report the port the server actually bound to.
    server_data['port'] = server.server_port
    return server

  def add_options(self):
    """Registers the policy-server-specific command-line options."""
    testserver_base.TestServerRunner.add_options(self)
    self.option_parser.add_option('--client-state', dest='client_state_file',
                                  help='File that client state should be '
                                  'persisted to. This allows the server to be '
                                  'seeded by a list of pre-registered clients '
                                  'and restarts without abandoning registered '
                                  'clients.')
    self.option_parser.add_option('--policy-key', action='append',
                                  dest='policy_keys',
                                  help='Specify a path to a PEM-encoded '
                                  'private key to use for policy signing. May '
                                  'be specified multiple times in order to '
                                  'load multiple keys into the server. If the '
                                  'server has multiple keys, it will rotate '
                                  'through them at each request in a '
                                  'round-robin fashion. The server will '
                                  'generate a random key if none is specified '
                                  'on the command line.')
    self.option_parser.add_option('--log-level', dest='log_level',
                                  default='WARN',
                                  help='Log level threshold to use.')
    self.option_parser.add_option('--config-file', dest='config_file',
                                  help='Specify a configuration file to use '
                                  'instead of the default '
                                  '<data_dir>/device_management')

  def run_server(self):
    """Configures root logging from the options, then runs the server."""
    logger = logging.getLogger()
    # --log-level is a logging level name (e.g. 'WARN'), looked up
    # case-insensitively as an attribute of the logging module.
    logger.setLevel(getattr(logging, str(self.options.log_level).upper()))
    # NOTE(review): log_to_console / log_file are presumably registered by
    # testserver_base.TestServerRunner.add_options; confirm there.
    if self.options.log_to_console:
      logger.addHandler(logging.StreamHandler())
    if self.options.log_file:
      logger.addHandler(logging.FileHandler(self.options.log_file))

    testserver_base.TestServerRunner.run_server(self)
945
946
if __name__ == '__main__':
  # Propagate the runner's result as the process exit status.
  runner = PolicyServerRunner()
  sys.exit(runner.main())
949