1# Copyright (C) 2005 Martin v. Löwis
2# Licensed to PSF under a Contributor Agreement.
3from _msi import *
4import fnmatch
5import os
6import re
7import string
8import sys
9
# Target-architecture detection, based on the build banner in sys.version.
AMD64 = "AMD64" in sys.version
Itanium = "Itanium" in sys.version
Win64 = AMD64 or Itanium

# Bit layout of the column type codes decoded by Table.sql()
# (partially taken from Wine).
datasizemask = 0x00ff      # low byte: column width (e.g. the N in CHAR(N))
type_valid = 0x0100
type_localizable = 0x0200

typemask = 0x0c00          # bits selecting the storage type below
type_long = 0x0000
type_short = 0x0400
type_binary = 0x0800
type_string = 0x0c00

type_nullable = 0x1000
type_key = 0x2000

# XXX temporary, localizable?
# Every bit Table.sql() understands; anything else is reported as unknown.
knownbits = (datasizemask | type_valid | type_localizable
             | typemask | type_nullable | type_key)
30
class Table:
    """Schema of one MSI table, used to generate its CREATE TABLE statement."""

    def __init__(self, name):
        self.name = name
        self.fields = []  # (index, name, type) tuples; index is 1-based

    def add_field(self, index, name, type):
        """Declare column *name* at 1-based position *index*.

        *type* packs the column width (``datasizemask`` bits) together with
        the ``type_*`` flag bits defined at module level.
        """
        self.fields.append((index,name,type))

    def sql(self):
        """Return a ``CREATE TABLE`` statement for the declared columns.

        Columns are emitted in index order; the indices are assumed to be
        exactly 1..len(self.fields). Columns whose type carries ``type_key``
        form the PRIMARY KEY. Type bits outside ``knownbits`` are reported
        on stdout rather than raising.
        """
        self.fields.sort()
        fields = [None]*len(self.fields)
        keys = []
        for index, name, ftype in self.fields:
            index -= 1  # 1-based column index -> list position
            unk = ftype & ~knownbits
            if unk:
                print("%s.%s unknown bits %x" % (self.name, name, unk))
            size = ftype & datasizemask
            dtype = ftype & typemask
            if dtype == type_string:
                if size:
                    tname = "CHAR(%d)" % size
                else:
                    tname = "CHAR"
            elif dtype == type_short:
                assert size==2
                tname = "SHORT"
            elif dtype == type_long:
                assert size==4
                tname = "LONG"
            elif dtype == type_binary:
                assert size==0
                tname = "OBJECT"
            else:
                # Defensive only: ftype & typemask can take just the four
                # values handled above.
                tname = "unknown"
                print("%s.%s unknown integer type %d" % (self.name, name, size))
            if ftype & type_nullable:
                flags = ""
            else:
                flags = " NOT NULL"
            if ftype & type_localizable:
                flags += " LOCALIZABLE"
            fields[index] = "`%s` %s%s" % (name, tname, flags)
            if ftype & type_key:
                keys.append("`%s`" % name)
        fields = ", ".join(fields)
        keys = ", ".join(keys)
        return "CREATE TABLE %s (%s PRIMARY KEY %s)" % (self.name, fields, keys)

    def create(self, db):
        """Create this table in *db* by executing sql().

        The view is closed even if execution fails.
        """
        v = db.OpenView(self.sql())
        try:
            v.Execute(None)
        finally:
            v.Close()
85
class _Unspecified:
    """Sentinel meaning "argument not given", distinct from any explicit
    value such as None."""
    pass

def change_sequence(seq, action, seqno=_Unspecified, cond = _Unspecified):
    """Change the sequence number of an action in a sequence list.

    *seq* is a list of (action, condition, sequence) tuples and is modified
    in place: the first entry whose action equals *action* is replaced by
    a new (action, cond, seqno) tuple. *seqno* and/or *cond* override the
    stored values only when supplied.

    Raises ValueError if *action* does not appear in *seq*.
    """
    for i, entry in enumerate(seq):
        if entry[0] == action:
            if cond is _Unspecified:
                cond = entry[1]
            if seqno is _Unspecified:
                seqno = entry[2]
            seq[i] = (action, cond, seqno)
            return
    raise ValueError("Action not found in sequence")
98
def add_data(db, table, values):
    """Insert each row of *values* into *table* of the open database *db*.

    Every row must contain exactly one item per column of *table*. Items
    are stored according to their Python type: ``int`` via SetInteger,
    ``str`` via SetString, ``Binary`` as a stream read from the file named
    by its ``name`` attribute, and ``None`` leaves the field unset.

    Raises TypeError for any other item type, and MSIError (chained to the
    underlying error) when the database rejects a row. The view is closed
    in all cases.
    """
    v = db.OpenView("SELECT * FROM `%s`" % table)
    try:
        count = v.GetColumnInfo(MSICOLINFO_NAMES).GetFieldCount()
        r = CreateRecord(count)
        for value in values:
            assert len(value) == count, value
            for i in range(count):
                field = value[i]
                # Record fields are numbered from 1.
                if isinstance(field, int):
                    r.SetInteger(i+1, field)
                elif isinstance(field, str):
                    r.SetString(i+1, field)
                elif field is None:
                    pass
                elif isinstance(field, Binary):
                    r.SetStream(i+1, field.name)
                else:
                    raise TypeError("Unsupported type %s" % field.__class__.__name__)
            try:
                v.Modify(MSIMODIFY_INSERT, r)
            except Exception as e:
                # Report the offending row rather than the whole batch.
                raise MSIError("Could not insert "+repr(value)+" into "+table) from e
            # Reset the record so that fields left as None in the next row
            # do not keep this row's values.
            r.ClearData()
    finally:
        v.Close()
124
125
def add_stream(db, name, path):
    """Store the contents of the file at *path* in *db*'s ``_Streams``
    table under the stream name *name*.

    The view is closed even if the insert fails.
    """
    # NOTE(review): *name* is interpolated into the SQL literal unescaped,
    # so it must not contain a single quote.
    v = db.OpenView("INSERT INTO _Streams (Name, Data) VALUES ('%s', ?)" % name)
    try:
        r = CreateRecord(1)
        r.SetStream(1, path)
        v.Execute(r)
    finally:
        v.Close()
132
def init_database(name, schema,
                  ProductName, ProductCode, ProductVersion,
                  Manufacturer):
    """Create a new MSI database at path *name* and return it, open.

    Steps, in order: remove any existing file at *name* (OSError ignored),
    upper-case *ProductCode*, create the database, create every table in
    ``schema.tables``, load ``schema._Validation_records`` into
    ``_Validation``, write the summary information stream, insert the core
    product properties into ``Property``, and commit.
    """
    try:
        os.unlink(name)
    except OSError:
        pass
    ProductCode = ProductCode.upper()

    db = OpenDatabase(name, MSIDBOPEN_CREATE)
    for table in schema.tables:
        table.create(db)
    add_data(db, "_Validation", schema._Validation_records)

    if Itanium:
        template = "Intel64;1033"
    elif AMD64:
        template = "x64;1033"
    else:
        template = "Intel;1033"

    # Summary information, allowing at most 20 properties.
    summary = db.GetSummaryInformation(20)
    for pid, setting in (
            (PID_TITLE, "Installation Database"),
            (PID_SUBJECT, ProductName),
            (PID_AUTHOR, Manufacturer),
            (PID_TEMPLATE, template),
            (PID_REVNUMBER, gen_uuid()),
            (PID_WORDCOUNT, 2),    # long file names, compressed, original media
            (PID_PAGECOUNT, 200),  # presumably the required installer version (2.0) -- confirm
            (PID_APPNAME, "Python MSI Library"),
    ):
        summary.SetProperty(pid, setting)
    # XXX more properties
    summary.Persist()

    product_properties = [
        ("ProductName", ProductName),
        ("ProductCode", ProductCode),
        ("ProductVersion", ProductVersion),
        ("Manufacturer", Manufacturer),
        ("ProductLanguage", "1033"),
    ]
    add_data(db, "Property", product_properties)
    db.Commit()
    return db
173
def add_tables(db, module):
    """Bulk-load table rows from *module* into *db*.

    ``module.tables`` names the tables to fill; the rows for each table are
    read from the module attribute of the same name.
    """
    for table_name in module.tables:
        rows = getattr(module, table_name)
        add_data(db, table_name, rows)
177
def make_id(str):
    """Turn *str* into a valid MSI identifier.

    Every character other than an ASCII letter, digit, '.' or '_' is
    replaced by '_'. If the result is empty or starts with a digit or '.',
    it is prefixed with '_' so that it starts with a letter or underscore
    (an empty input therefore yields "_").
    """
    identifier_chars = string.ascii_letters + string.digits + "._"
    ident = "".join([c if c in identifier_chars else "_" for c in str])
    # Guard the empty case before indexing ident[0].
    if not ident or ident[0] in (string.digits + "."):
        ident = "_" + ident
    assert re.match("^[A-Za-z_][A-Za-z0-9_.]*$", ident), "FILE"+ident
    return ident
185
def gen_uuid():
    """Return a freshly created GUID, upper-cased and wrapped in braces."""
    raw = UuidCreate()
    return "{%s}" % raw.upper()
188
class CAB:
    """Collects source files for one cabinet and embeds it in a database.

    Files are registered with append(); commit() builds the cabinet, adds
    the matching Media row, and stores the cabinet as a stream named after
    the cabinet.
    """

    def __init__(self, name):
        self.name = name
        self.files = []         # (source path, logical id) pairs, in order
        self.filenames = set()  # logical ids produced by gen_id()
        self.index = 0          # sequence number of the last appended file

    def gen_id(self, file):
        """Return a logical id derived from *file* via make_id() that has
        not been produced by this CAB before.

        Collisions are resolved by appending ".1", ".2", ...
        """
        logical = _logical = make_id(file)
        pos = 1
        while logical in self.filenames:
            logical = "%s.%d" % (_logical, pos)
            pos += 1
        self.filenames.add(logical)
        return logical

    def append(self, full, file, logical):
        """Queue the source file *full* for the cabinet.

        If *logical* is empty, an id is generated from *file*. Returns
        (sequence number, logical id), or None when *full* is a directory
        (directories are skipped).
        """
        if os.path.isdir(full):
            return
        if not logical:
            logical = self.gen_id(file)
        self.index += 1
        self.files.append((full, logical))
        return self.index, logical

    def commit(self, db):
        """Build the cabinet, embed it in *db*, and commit *db*.

        The temporary cabinet file is removed even if a step fails.
        """
        import tempfile
        # mkstemp creates the file atomically, avoiding the name race of the
        # deprecated tempfile.mktemp(); the handle is not needed, only the path.
        fd, filename = tempfile.mkstemp()
        os.close(fd)
        try:
            FCICreate(filename, self.files)
            add_data(db, "Media",
                    [(1, self.index, None, "#"+self.name, None, None)])
            add_stream(db, self.name, filename)
        finally:
            os.unlink(filename)
        db.Commit()
223
# Logical names already taken by Directory objects in this process; used to
# give every new Directory a unique logical name.
_directories = set()

class Directory:
    """A row in the Directory table plus the component/file bookkeeping for it."""

    def __init__(self, db, cab, basedir, physical, _logical, default, componentflags=None):
        """Create a new directory in the Directory table. There is a current component
        at each point in time for the directory, which is either explicitly created
        through start_component, or implicitly when files are added for the first
        time. Files are added into the current component, and into the cab file.
        To create a directory, a base directory object needs to be specified (can be
        None), the path to the physical directory, and a logical directory name.
        Default specifies the DefaultDir slot in the directory table. componentflags
        specifies the default flags that new components get."""
        index = 1
        _logical = make_id(_logical)
        logical = _logical
        # On a clash, append a counter: "name1", "name2", ...
        while logical in _directories:
            logical = "%s%d" % (_logical, index)
            index += 1
        _directories.add(logical)
        self.db = db
        self.cab = cab
        self.basedir = basedir
        self.physical = physical
        self.logical = logical
        self.component = None     # name of the current component, if any
        self.short_names = set()  # short (8.3) names already used here
        self.ids = set()          # logical File ids added to this directory
        self.keyfiles = {}        # keyfile name -> logical id (start_component)
        self.componentflags = componentflags
        if basedir:
            self.absolute = os.path.join(basedir.absolute, physical)
            blogical = basedir.logical
        else:
            self.absolute = physical
            blogical = None
        add_data(db, "Directory", [(logical, blogical, default)])

    def start_component(self, component = None, feature = None, flags = None, keyfile = None, uuid=None):
        """Add an entry to the Component table, and make this component the current for this
        directory. If no component name is given, the directory name is used. If no feature
        is given, the current feature is used. If no flags are given, the directory's default
        flags are used (0 if the directory has none either). If no keyfile is given, the
        KeyPath is left null in the Component table. If no uuid is given a new one is
        generated; a supplied uuid is upper-cased."""
        if flags is None:
            flags = self.componentflags
        if flags is None:
            # Neither the caller nor the directory supplied flags; start from 0
            # so the Win64 bit below can be OR-ed in (None | 256 raises).
            flags = 0
        if uuid is None:
            uuid = gen_uuid()
        else:
            uuid = uuid.upper()
        if component is None:
            component = self.logical
        self.component = component
        if Win64:
            flags |= 256  # presumably msidbComponentAttributes64bit -- confirm
        if keyfile:
            # CAB.gen_id() takes only the file name to derive the id from.
            keyid = self.cab.gen_id(keyfile)
            self.keyfiles[keyfile] = keyid
        else:
            keyid = None
        add_data(self.db, "Component",
                        [(component, uuid, self.logical, flags, None, keyid)])
        if feature is None:
            feature = current_feature
        add_data(self.db, "FeatureComponents",
                        [(feature.id, component)])

    def make_short(self, file):
        """Return an 8.3-style short name for *file* that is unique within
        this directory, and record it as used.

        The upper-cased name is used unchanged when it already fits (at most
        one dot, a stem of <= 8 and an extension of <= 3 characters, no
        characters had to be replaced or removed) and is not taken yet;
        otherwise a "STEM~N.EXT" form is generated.
        """
        oldfile = file
        file = file.replace('+', '_')
        # Drop characters that are not allowed in short names.
        file = ''.join(c for c in file if c not in r' "/\[]:;=,')
        parts = file.split(".")
        if len(parts) > 1:
            prefix = "".join(parts[:-1]).upper()
            suffix = parts[-1].upper()
            if not prefix:
                # Nothing before the last dot (e.g. ".cvsignore"): use the
                # remainder as the stem.
                prefix = suffix
                suffix = None
        else:
            prefix = file.upper()
            suffix = None
        if len(parts) < 3 and len(prefix) <= 8 and file == oldfile and (
                                                not suffix or len(suffix) <= 3):
            if suffix:
                file = prefix+"."+suffix
            else:
                file = prefix
        else:
            file = None
        if file is None or file in self.short_names:
            prefix = prefix[:6]
            if suffix:
                suffix = suffix[:3]
            pos = 1
            while True:
                if suffix:
                    file = "%s~%d.%s" % (prefix, pos, suffix)
                else:
                    file = "%s~%d" % (prefix, pos)
                if file not in self.short_names: break
                pos += 1
                assert pos < 10000
                # Shorten the stem as the counter gains a digit so that
                # stem + "~N" stays at 8 characters.
                if pos in (10, 100, 1000):
                    prefix = prefix[:-1]
        self.short_names.add(file)
        assert not re.search(r'[\?|><:/*"+,;=\[\]]', file) # restrictions on short names
        return file

    def add_file(self, file, src=None, version=None, language=None):
        """Add a file to the current component of the directory, starting a new one
        if there is no current component. By default, the file name in the source
        and the file table will be identical. If the src file is specified, it is
        interpreted relative to the current directory. Optionally, a version and a
        language can be specified for the entry in the File table.
        Returns the logical id of the new File row."""
        if not self.component:
            self.start_component(self.logical, current_feature, 0)
        if not src:
            # Allow relative paths for file if src is not specified
            src = file
            file = os.path.basename(file)
        absolute = os.path.join(self.absolute, src)
        # The '"' must be inside the character class; outside it the pattern
        # only matched a forbidden character immediately followed by a quote.
        assert not re.search(r'[\?|><:/*"]', file) # restrictions on long names
        if file in self.keyfiles:
            logical = self.keyfiles[file]
        else:
            logical = None
        sequence, logical = self.cab.append(absolute, file, logical)
        assert logical not in self.ids
        self.ids.add(logical)
        short = self.make_short(file)
        full = "%s|%s" % (short, file)
        filesize = os.stat(absolute).st_size
        # constants.msidbFileAttributesVital
        # Compressed omitted, since it is the database default
        # could add r/o, system, hidden
        attributes = 512
        add_data(self.db, "File",
                        [(logical, self.component, full, filesize, version,
                         language, attributes, sequence)])
        #if not version:
        #    # Add hash if the file is not versioned
        #    filehash = FileHash(absolute, 0)
        #    add_data(self.db, "MsiFileHash",
        #             [(logical, 0, filehash.IntegerData(1),
        #               filehash.IntegerData(2), filehash.IntegerData(3),
        #               filehash.IntegerData(4))])
        # Automatically remove .pyc files on uninstall (2)
        # XXX: adding so many RemoveFile entries makes installer unbelievably
        # slow. So instead, we have to use wildcard remove entries
        if file.endswith(".py"):
            add_data(self.db, "RemoveFile",
                      [(logical+"c", self.component, "%sC|%sc" % (short, file),
                        self.logical, 2),
                       (logical+"o", self.component, "%sO|%so" % (short, file),
                        self.logical, 2)])
        return logical

    def glob(self, pattern, exclude = None):
        """Add a list of files to the current component as specified in the
        glob pattern. Individual files can be excluded in the exclude list.
        Names starting with '.' are skipped unless the pattern itself starts
        with '.'. Returns the matching names (excluded names included), or
        [] if the directory cannot be listed."""
        try:
            files = os.listdir(self.absolute)
        except OSError:
            return []
        if pattern[:1] != '.':
            files = (f for f in files if f[0] != '.')
        files = fnmatch.filter(files, pattern)
        for f in files:
            if exclude and f in exclude: continue
            self.add_file(f)
        return files

    def remove_pyc(self):
        "Remove .pyc files on uninstall"
        add_data(self.db, "RemoveFile",
                 [(self.component+"c", self.component, "*.pyc", self.logical, 2)])
398
class Binary:
    """Marks a table value that add_data() stores as a stream read from a file."""

    def __init__(self, fname):
        # Path of the file whose contents become the stream.
        self.name = fname

    def __repr__(self):
        # Renders as Python source that refers to a `dirname` variable,
        # presumably for code that regenerates table modules -- confirm.
        template = 'msilib.Binary(os.path.join(dirname,"%s"))'
        return template % self.name
404
class Feature:
    """A row in the Feature table; can be made the module's current feature."""

    def __init__(self, db, id, title, desc, display, level = 1,
                 parent=None, directory = None, attributes=0):
        self.id = id
        # A parent Feature is referenced by its id; a falsy parent is stored as given.
        parent_id = parent.id if parent else parent
        row = (id, parent_id, title, desc, display,
               level, directory, attributes)
        add_data(db, "Feature", [row])

    def set_current(self):
        """Make this feature the module-level ``current_feature``, which
        Directory.start_component() and Directory.add_file() fall back to."""
        global current_feature
        current_feature = self
417
class Control:
    """Handle to a control on a Dialog, used to add rows that refer to it."""

    def __init__(self, dlg, name):
        self.dlg, self.name = dlg, name

    def event(self, event, argument, condition = "1", ordering = None):
        """Add a ControlEvent row for this control."""
        row = (self.dlg.name, self.name, event, argument, condition, ordering)
        add_data(self.dlg.db, "ControlEvent", [row])

    def mapping(self, event, attribute):
        """Add an EventMapping row for this control."""
        row = (self.dlg.name, self.name, event, attribute)
        add_data(self.dlg.db, "EventMapping", [row])

    def condition(self, action, condition):
        """Add a ControlCondition row for this control."""
        row = (self.dlg.name, self.name, action, condition)
        add_data(self.dlg.db, "ControlCondition", [row])
435
class RadioButtonGroup(Control):
    """A RadioButtonGroup control; individual buttons are added with add()."""

    def __init__(self, dlg, name, property):
        # Reuse the base initialiser instead of duplicating its assignments.
        super().__init__(dlg, name)
        self.property = property  # property the buttons are bound to
        self.index = 1            # 1-based position of the next button

    def add(self, name, x, y, w, h, text, value = None):
        """Add a RadioButton row for this group.

        *value* defaults to *name*. The button's position in the group is
        taken from, and then advances, ``self.index``.
        """
        if value is None:
            value = name
        add_data(self.dlg.db, "RadioButton",
                 [(self.property, self.index, value,
                   x, y, w, h, text, None)])
        self.index += 1
450
class Dialog:
    """A row in the Dialog table plus helpers that add controls to it."""

    def __init__(self, db, name, x, y, w, h, attr, title, first, default, cancel):
        self.db = db
        self.name = name
        self.x, self.y, self.w, self.h = x,y,w,h
        add_data(db, "Dialog", [(name, x,y,w,h,attr,title,first,default,cancel)])

    def control(self, name, type, x, y, w, h, attr, prop, text, next, help):
        """Add a Control row of the given *type* to this dialog and return a
        Control handle for it."""
        add_data(self.db, "Control",
                 [(self.name, name, type, x, y, w, h, attr, prop, text, next, help)])
        return Control(self, name)

    def text(self, name, x, y, w, h, attr, text):
        """Add a Text control."""
        return self.control(name, "Text", x, y, w, h, attr, None,
                     text, None, None)

    def bitmap(self, name, x, y, w, h, text):
        """Add a Bitmap control (attributes fixed to 1)."""
        return self.control(name, "Bitmap", x, y, w, h, 1, None, text, None, None)

    def line(self, name, x, y, w, h):
        """Add a Line control (attributes fixed to 1)."""
        return self.control(name, "Line", x, y, w, h, 1, None, None, None, None)

    def pushbutton(self, name, x, y, w, h, attr, text, next):
        """Add a PushButton control."""
        return self.control(name, "PushButton", x, y, w, h, attr, None, text, next, None)

    def radiogroup(self, name, x, y, w, h, attr, prop, text, next):
        """Add a RadioButtonGroup control bound to *prop* and return a
        RadioButtonGroup handle for adding its buttons."""
        # Insert the Control row through control() like every other helper
        # instead of duplicating the add_data() call.
        self.control(name, "RadioButtonGroup", x, y, w, h, attr, prop, text, next, None)
        return RadioButtonGroup(self, name, prop)

    def checkbox(self, name, x, y, w, h, attr, prop, text, next):
        """Add a CheckBox control bound to *prop*."""
        return self.control(name, "CheckBox", x, y, w, h, attr, prop, text, next, None)
484