]> git.tdb.fi Git - libs/gl.git/blob - scripts/extgen.py
f5df83f21588ab0cbb71f94d8a94a8d05023af7a
[libs/gl.git] / scripts / extgen.py
1 #!/usr/bin/python
2
3 import sys
4 import os
5 import xml.dom
6 import xml.dom.minidom
7 import itertools
8 import re
9
class Version:
	"""A major.minor API version number.

	All comparison operators are defined and consistent with each other and
	with hashing, so versions can be compared with <, <=, >, >=, ==, != and
	used as dictionary keys.  None compares lower than every Version, so a
	version can be compared against a missing one.
	"""

	def __init__(self, *args):
		"""Create Version(major, minor), or 0.0 when called with no arguments.

		Raises TypeError for any other number of arguments.
		"""
		if len(args)==0:
			self.major = 0
			self.minor = 0
		elif len(args)==2:
			self.major = args[0]
			self.minor = args[1]
		else:
			raise TypeError("__init__() takes zero or two arguments (%d given)"%len(args))

	def __str__(self):
		return "%d.%d"%(self.major, self.minor)

	def __repr__(self):
		return "Version(%d, %d)"%(self.major, self.minor)

	def as_define(self):
		"""Return the version part of a preprocessor macro name, e.g. VERSION_3_0."""
		return "VERSION_%d_%d"%(self.major, self.minor)

	def _compare(self, other):
		"""Return <0, 0 or >0 as self is lower than, equal to or higher than other.

		None is treated as lower than any version.
		"""
		if other is None:
			return 1
		mine = (self.major, self.minor)
		theirs = (other.major, other.minor)
		return (mine>theirs)-(mine<theirs)

	def __lt__(self, other):
		return self._compare(other)<0

	def __le__(self, other):
		return self._compare(other)<=0

	def __gt__(self, other):
		return self._compare(other)>0

	def __ge__(self, other):
		return self._compare(other)>=0

	def __eq__(self, other):
		if not isinstance(other, Version):
			return False
		return self._compare(other)==0

	def __ne__(self, other):
		# Python 2 does not derive != from ==
		return not self.__eq__(other)

	def __hash__(self):
		# Must agree with __eq__ so equal versions share a dictionary slot
		return hash((self.major, self.minor))
45
46
class Thing:
	"""Base for named registry items: functions (FUNCTION) and enums (ENUM).

	aliases lists names this thing is declared an alias of; api_support maps
	an API name to the ApiSupport record describing the thing in that API.
	"""
	FUNCTION = 1
	ENUM = 2

	class ApiSupport:
		"""Per-API information about one thing."""
		def __init__(self):
			self.core_version = None        # version in which it became core
			self.deprecated_version = None  # version whose core profile removes it
			self.extensions = []            # Extension objects providing it
			self.sources = []               # extension things aliased to it

	def __init__(self, name, kind):
		self.name = name
		self.kind = kind
		self.aliases = []
		self.api_support = {}

	def get_or_create_api_support(self, api):
		"""Return the ApiSupport for api, adding an empty one the first time."""
		existing = self.api_support.get(api)
		if existing:
			return existing
		created = Thing.ApiSupport()
		self.api_support[api] = created
		return created
70
71
class Function(Thing):
	"""A registry command: its return type text and one text entry per parameter."""
	def __init__(self, name):
		Thing.__init__(self, name, Thing.FUNCTION)
		self.params = []
		self.return_type = "void"
77
78
79 r_bitmask = re.compile("_BIT[0-9]*(_|$)")
80
81 class Enum(Thing):
82         def __init__(self, name):
83                 Thing.__init__(self, name, Thing.ENUM)
84                 self.value = 0
85                 self.bitmask = bool(r_bitmask.search(self.name))
86
87
class Extension:
	"""An extension as provided by one API.

	The name (without the GL_ prefix) is split at its first underscore into
	ext_type (the vendor part, e.g. ARB, EXT, OES) and base_name.  preference
	ranks providers: ARB and OES 2, EXT 1, anything else 0.
	"""
	_PREFERENCES = {"EXT": 1, "ARB": 2, "OES": 2}

	def __init__(self, name, api):
		self.name = name
		underscore = name.find('_')
		if underscore<0:
			# No vendor part.  Slicing with find()'s -1 would have chopped
			# the last character off the name to form ext_type.
			self.ext_type = ""
			self.base_name = name
		else:
			self.ext_type = name[:underscore]
			self.base_name = name[underscore+1:]
		self.api = api
		self.things = {}
		self.preference = Extension._PREFERENCES.get(self.ext_type, 0)
		self.backport = False
102
103
class Api:
	"""One API of the registry (e.g. gl, gles2) and what it provides."""
	def __init__(self, name):
		self.name = name
		self.extensions = {}       # extension name -> Extension
		self.core_things = {}      # thing name -> Thing that is core in this API
		self.latest_version = None # not updated by the parser in this file
110
111
def get_nested_elements(elem, path):
	"""Return the elements reached from elem along a slash-separated tag path.

	Each path component selects direct child elements with that tag name;
	results are in document order.
	"""
	head, sep, tail = path.partition('/')
	matches = [c for c in elem.childNodes
		if c.nodeType==xml.dom.Node.ELEMENT_NODE and c.tagName==head]
	if not sep:
		return matches
	found = []
	for m in matches:
		found.extend(get_nested_elements(m, tail))
	return found
123
def get_first_child(elem, tag):
	"""Return the first direct child element of elem with the given tag, or None."""
	matches = (c for c in elem.childNodes
		if c.nodeType==xml.dom.Node.ELEMENT_NODE and c.tagName==tag)
	return next(matches, None)
129
def get_text_contents(node):
	"""Return all text and CDATA inside node, including descendants, in document order."""
	text_types = (xml.dom.Node.TEXT_NODE, xml.dom.Node.CDATA_SECTION_NODE)
	pieces = []
	for child in node.childNodes:
		if child.nodeType in text_types:
			pieces.append(child.data)
		else:
			pieces.append(get_text_contents(child))
	return "".join(pieces)
138
def get_or_create(map, name, type, *args):
	"""Return map[name], first storing type(name, *args) there if it is missing or falsy."""
	existing = map.get(name)
	if existing:
		return existing
	created = type(name, *args)
	map[name] = created
	return created
145
146
class GlXmlParser:
	"""Reads gl.xml-style registry files into Function/Enum things, per-API
	core sets and per-API extensions.

	Call parse_file() for each file, then finalize() once.
	"""

	def __init__(self, host_api_name, target_ext_name):
		self.host_api_name = host_api_name      # API assumed for <feature> without api=
		self.target_ext_name = target_ext_name  # given top preference in sort_extensions
		self.apis = {}                          # API name -> Api
		self.things = {}                        # thing name -> Function or Enum

	def _referenced_things(self, elem):
		"""Return the known things named by elem's <command> and <enum> children, commands first."""
		refs = itertools.chain(get_nested_elements(elem, "command"), get_nested_elements(elem, "enum"))
		names = [r.getAttribute("name") for r in refs]
		return [self.things[n] for n in names if n in self.things]

	def parse_command(self, cmd):
		"""Record a <command> as a Function with its aliases, return type and parameters."""
		proto = get_first_child(cmd, "proto")
		name = get_text_contents(get_first_child(proto, "name"))
		func = get_or_create(self.things, name, Function)

		aliases = get_nested_elements(cmd, "alias")
		func.aliases = [a.getAttribute("name") for a in aliases]

		# The return type is all of <proto> before <name>.  Taking only the
		# <ptype> text would lose qualifiers and pointers, as in
		# "const <ptype>GLubyte</ptype> *".
		return_parts = []
		for c in proto.childNodes:
			if c.nodeType==xml.dom.Node.ELEMENT_NODE and c.tagName=="name":
				break
			if c.nodeType==xml.dom.Node.TEXT_NODE:
				return_parts.append(c.data)
			else:
				return_parts.append(get_text_contents(c))
		return_type = " ".join("".join(return_parts).split())
		if return_type:
			func.return_type = return_type

		# A real list: the parameter texts are joined when writing output
		func.params = [get_text_contents(p) for p in get_nested_elements(cmd, "param")]

	def parse_enum(self, en):
		"""Record an <enum> as an Enum; its value attribute is parsed as hexadecimal."""
		name = en.getAttribute("name")
		enum = get_or_create(self.things, name, Enum)

		enum.value = int(en.getAttribute("value"), 16)

	def parse_feature(self, feat):
		"""Apply a <feature> (an API version) to the already parsed things.

		Required things become core in the API, keeping the lowest version.
		Things removed by the core profile get that version as their
		deprecated version; other removals drop the thing from the API's core
		set and from every extension listing it.
		"""
		api_name = feat.getAttribute("api")
		if not api_name:
			api_name = self.host_api_name
		api = get_or_create(self.apis, api_name, Api)

		version = feat.getAttribute("number")
		if version:
			version = Version(*[int(x) for x in version.split('.')])
		else:
			version = None

		for req in get_nested_elements(feat, "require"):
			for thing in self._referenced_things(req):
				supp = thing.get_or_create_api_support(api.name)
				if not supp.core_version or version<supp.core_version:
					supp.core_version = version
				api.core_things[thing.name] = thing

		for rem in get_nested_elements(feat, "remove"):
			profile = rem.getAttribute("profile")
			for thing in self._referenced_things(rem):
				if profile!="core":
					api.core_things.pop(thing.name, None)
					for s in thing.api_support.values():
						for e in s.extensions:
							# pop: the thing may already be gone from e (an
							# earlier removal, or e listed twice)
							e.things.pop(thing.name, None)
				else:
					supp = thing.get_or_create_api_support(api.name)
					supp.deprecated_version = version

	def parse_extension(self, ext):
		"""Register an <extension> with each known API in its supported attribute.

		Things in a <require> without an api attribute go to every supported
		API; the others only to the API named.
		"""
		ext_things_by_api = {}
		for req in get_nested_elements(ext, "require"):
			api_things = ext_things_by_api.setdefault(req.getAttribute("api"), [])
			api_things.extend(self._referenced_things(req))

		ext_name = ext.getAttribute("name")
		if ext_name.startswith("GL_"):
			ext_name = ext_name[3:]

		common_things = ext_things_by_api.get("", [])
		for api_name in ext.getAttribute("supported").split('|'):
			api = self.apis.get(api_name)
			if not api:
				continue

			extension = get_or_create(api.extensions, ext_name, Extension, api)
			api_things = ext_things_by_api.get(api_name, [])
			for t in itertools.chain(common_things, api_things):
				extension.things[t.name] = t
				supp = t.get_or_create_api_support(api.name)
				# The same extension can be parsed again from another file
				if extension not in supp.extensions:
					supp.extensions.append(extension)

	def parse_file(self, fn):
		"""Parse one registry file: commands, enums, features, then extensions."""
		doc = xml.dom.minidom.parse(fn)
		root = doc.documentElement

		for cmd in get_nested_elements(root, "commands/command"):
			self.parse_command(cmd)

		for en in get_nested_elements(root, "enums/enum"):
			self.parse_enum(en)

		for feat in get_nested_elements(root, "feature"):
			self.parse_feature(feat)

		for ext in get_nested_elements(root, "extensions/extension"):
			self.parse_extension(ext)

	def check_backport_extensions(self, api):
		"""Mark as backport each extension of api where none of the things has a name ending in the extension's type."""
		for e in api.extensions.values():
			e.backport = not any(t.name.endswith(e.ext_type) for t in e.things.values())

	def resolve_enum_aliases(self, api):
		"""Alias extension enums that are not core in api to a matching core enum.

		The match is by name with the vendor suffix removed, or failing that by
		value with the same bitmask flag; it is recorded only if the values agree.
		"""
		core_enums = [t for t in api.core_things.values() if t.kind==Thing.ENUM]
		core_enums_by_value = {}
		for c in core_enums:
			core_enums_by_value.setdefault(c.value, []).append(c)

		for e in api.extensions.values():
			enum_suffix = "_"+e.ext_type
			for n in e.things.values():
				if n.kind!=Thing.ENUM or n.api_support[api.name].core_version:
					continue

				name = n.name
				if name.endswith(enum_suffix):
					name = name[:-len(enum_suffix)]
				ce = api.core_things.get(name)
				if not ce:
					for c in core_enums_by_value.get(n.value, []):
						if c.bitmask==n.bitmask:
							ce = c
							break
				if ce and ce.value==n.value and ce.name not in n.aliases:
					n.aliases.append(ce.name)

	def resolve_sources(self, api):
		"""Record each extension thing as a source of the core things it aliases in api."""
		for e in api.extensions.values():
			for t in e.things.values():
				for a in t.aliases:
					# There are a few cases where a vendor function is aliased to
					# an EXT or ARB function but those are rare and not relevant for
					# our use
					alias = api.core_things.get(a)
					if alias:
						sources = alias.api_support[api.name].sources
						if t not in sources:
							sources.append(t)

	def sort_extensions(self):
		"""Order every thing's extension list by preference, the target extension first."""
		for a in self.apis.values():
			e = a.extensions.get(self.target_ext_name)
			if e:
				e.preference = 3
		for t in self.things.values():
			for s in t.api_support.values():
				s.extensions.sort(key=(lambda e: e.preference), reverse=True)

	def finalize(self):
		"""Derive backport flags, enum aliases, sources and extension order after all files are parsed."""
		for a in self.apis.values():
			self.check_backport_extensions(a)
			self.resolve_enum_aliases(a)
			self.resolve_sources(a)
		self.sort_extensions()
338
339
def detect_core_version(host_api, things):
	"""Return the host_api core version shared by the most things, or None.

	Things without a core version in host_api are not counted.  A warning is
	printed when the runner-up is at most one thing behind the winner.
	"""
	votes = {}
	for t in things:
		supp = t.api_support.get(host_api.name)
		if not supp or not supp.core_version:
			continue
		votes[supp.core_version] = votes.get(supp.core_version, 0)+1

	if not votes:
		return None

	ranked = [(count, version) for version, count in votes.items()]
	if len(ranked)>1:
		ranked.sort(reverse=True)
		best, runner_up = ranked[0], ranked[1]
		if runner_up[0]+1>=best[0]:
			print("Warning: multiple likely core version candidates: %s %s"%(best[1], runner_up[1]))
	return ranked[0][1]
354
def detect_deprecated_version(host_api, things):
	"""Return the lowest host_api deprecated version among things.

	Returns None when things is empty or when any thing has no deprecated
	version in host_api.
	"""
	found = []
	for t in things:
		supp = t.api_support.get(host_api.name)
		if not (supp and supp.deprecated_version):
			return None
		found.append(supp.deprecated_version)
	if not found:
		return None
	return min(found)
368
def detect_backport_extension(host_api, target_ext, things):
	"""Return the backport extension matching target_ext's base name, or None.

	Candidates are the backport extensions of the things that are core in
	host_api.  Warnings are printed if there are several candidates, or a
	single one whose base name differs from target_ext's.
	"""
	candidates = []
	for t in things:
		supp = t.api_support.get(host_api.name)
		if not (supp and supp.core_version):
			continue
		for ext in supp.extensions:
			if ext.backport and ext not in candidates:
				candidates.append(ext)

	if len(candidates)>1:
		print("Warning: multiple backport extension candidates: %s"%(" ".join(e.name for e in candidates)))

	matching = [e for e in candidates if e.base_name==target_ext.base_name]
	if matching:
		return matching[0]

	if len(candidates)==1:
		print("Warning: potential backport extension has mismatched name: %s"%candidates[0].name)
	return None
387
def collect_extensions(thing, api, exts):
	"""Append to exts the non-backport extensions providing thing in api, then,
	recursively, those providing each of its sources.

	Extensions already in exts are not added again; nothing is returned.
	"""
	supp = thing.api_support.get(api)
	if not supp:
		return

	for candidate in supp.extensions:
		if candidate.backport or candidate in exts:
			continue
		exts.append(candidate)

	for source in supp.sources:
		collect_extensions(source, api, exts)
399
def detect_source_extension(host_api, target_ext, things):
	"""Choose the extension to load the things from when they are not core.

	Returns target_ext when host_api has an extension with its name.
	Otherwise, each non-backport extension found by collect_extensions gets
	one vote per thing it provides.  The extension with the most votes wins;
	ties go to the one encountered first.  Returns None if no thing has any
	such extension.
	"""
	if target_ext.name in host_api.extensions:
		return target_ext

	# Counted in first-seen order so ties are broken deterministically;
	# iterating a dict keyed by these objects is arbitrary on Python 2.
	counts = {}
	order = []
	for t in things:
		exts = []
		collect_extensions(t, host_api.name, exts)
		for e in exts:
			if e not in counts:
				counts[e] = 0
				order.append(e)
			counts[e] += 1

	largest_ext = None
	largest_count = 0
	for e in order:
		if counts[e]>largest_count:
			largest_ext = e
			largest_count = counts[e]

	return largest_ext
420
421
class SourceGenerator:
	"""Writes the C++ header and source for one extension on host_api.

	The header defines the enums and declares function pointer types and
	variables.  The source defines the pointers and an init function that
	loads them from the core API or from the source extension.
	"""

	def __init__(self, host_api, target_ext, things):
		self.host_api = host_api
		self.api_prefix = "GL"
		if self.host_api.name=="gles2":
			self.api_prefix = "GL_ES"
		self.target_ext = target_ext
		# Real lists (not filter objects): they are iterated several times
		self.funcs = sorted((t for t in things if t.kind==Thing.FUNCTION), key=(lambda f: f.name))
		self.func_typedefs = dict((f.name, "FPtr_"+f.name) for f in self.funcs)
		self.enums = sorted((t for t in things if t.kind==Thing.ENUM), key=(lambda e: e.value))
		self.core_version = detect_core_version(host_api, things)
		self.deprecated_version = detect_deprecated_version(host_api, things)
		self.backport_ext = detect_backport_extension(host_api, target_ext, things)
		self.source_ext = detect_source_extension(host_api, target_ext, things)

	def write_header_intro(self, out):
		out.write("#ifndef MSP_GL_%s_\n"%self.target_ext.name.upper())
		out.write("#define MSP_GL_%s_\n"%self.target_ext.name.upper())

		out.write("""
#include <msp/gl/extension.h>
#include <msp/gl/gl.h>

namespace Msp {
namespace GL {

""")

	def write_enum_definitions(self, out):
		"""Write #defines for the enums, grouped under #ifndef guards for the
		core version or first extension that provides them."""
		enums_by_category = {}
		for e in self.enums:
			cat = None
			supp = e.api_support.get(self.host_api.name)
			if supp:
				if supp.core_version:
					cat = "%s_%s"%(self.api_prefix, supp.core_version.as_define())
				elif supp.extensions:
					cat = "GL_"+supp.extensions[0].name
			enums_by_category.setdefault(cat, []).append(e)

		# Unguarded enums (category None) first, then guards alphabetically;
		# the key avoids comparing None with strings.
		for cat in sorted(enums_by_category, key=(lambda c: (c is not None, c or ""))):
			if cat:
				out.write("#ifndef %s\n"%cat)
			for e in enums_by_category[cat]:
				out.write("#define %s 0x%04X\n"%(e.name, e.value))
			if cat:
				out.write("#endif\n")
			out.write("\n")

	def write_function_pointer_declarations(self, out):
		for f in self.funcs:
			typedef = self.func_typedefs[f.name]
			out.write("typedef %s (APIENTRY *%s)(%s);\n"%(f.return_type, typedef, ", ".join(f.params)))
			out.write("extern %s %s;\n"%(typedef, f.name))
			out.write("\n")

	def write_header_outro(self, out):
		out.write("""
} // namespace GL
} // namespace Msp

#endif
""")

	def write_source_intro(self, out):
		out.write("#include \"%s.h\"\n"%self.target_ext.name.lower())
		if self.funcs:
			out.write("""
#ifdef __APPLE__
#define GET_PROC_ADDRESS(x) &::x
#else
#define GET_PROC_ADDRESS(x) get_proc_address(#x)
#endif

#ifdef _WIN32
#define GET_PROC_ADDRESS_1_1(x) &::x
#else
#define GET_PROC_ADDRESS_1_1(x) GET_PROC_ADDRESS(x)
#endif
""")
		out.write("""
namespace Msp {
namespace GL {

""")

	def write_function_pointer_definitions(self, out):
		for f in self.funcs:
			out.write("%s %s = 0;\n"%(self.func_typedefs[f.name], f.name))

	def write_init_function(self, out):
		"""Write init_<ext>(), which loads the function pointers and returns
		CORE, EXTENSION or UNSUPPORTED."""
		out.write("\nExtension::SupportLevel init_%s()\n{\n"%self.target_ext.name.lower())
		if self.core_version:
			out.write("\tif(is_disabled(\"GL_%s\"))\n\t\treturn Extension::UNSUPPORTED;\n"%self.target_ext.name)
			out.write("#if !defined(__APPLE__) || defined(%s_%s)\n"%(self.api_prefix, self.core_version.as_define()))
			out.write("\tif(")
			if self.backport_ext:
				out.write("is_supported(\"GL_%s\") || "%self.backport_ext.name)
			out.write("is_supported(%r"%self.core_version)
			if self.deprecated_version:
				out.write(", %r"%self.deprecated_version)
			out.write("))\n\t{\n")
			for f in self.funcs:
				supp = f.api_support.get(self.host_api.name)
				if supp:
					gpa_suffix = ""
					# Core in 1.0/1.1: use the _1_1 lookup macro.  Written as
					# "not >" so only Version.__gt__ is needed.
					if supp.core_version is not None and not (supp.core_version>Version(1, 1)):
						gpa_suffix = "_1_1"
					out.write("\t\t%s = reinterpret_cast<%s>(GET_PROC_ADDRESS%s(%s));\n"%(f.name, self.func_typedefs[f.name], gpa_suffix, f.name))
			out.write("\t\treturn Extension::CORE;\n")
			out.write("\t}\n")
			out.write("#endif\n")
		if self.source_ext and self.source_ext!=self.backport_ext:
			out.write("#if !defined(__APPLE__) || defined(GL_%s)\n"%self.target_ext.name)
			out.write("\tif(is_supported(\"GL_%s\"))\n\t{\n"%(self.source_ext.name))
			for f in self.funcs:
				supp = f.api_support.get(self.host_api.name)
				if supp and supp.sources:
					# Prefer the source named with the source extension's type
					src = None
					for s in supp.sources:
						if s.name.endswith(self.source_ext.ext_type):
							src = s
							break
					if not src:
						src = supp.sources[0]
				else:
					src = f

				if self.host_api.name in src.api_support:
					if not src.name.endswith(self.source_ext.ext_type):
						print("Warning: %s does not match extension type %s"%(src.name, self.source_ext.ext_type))
					out.write("\t\t%s = reinterpret_cast<%s>(GET_PROC_ADDRESS(%s));\n"%(f.name, self.func_typedefs[f.name], src.name))
			out.write("\t\treturn Extension::EXTENSION;\n")
			out.write("\t}\n")
			out.write("#endif\n")
		out.write("\treturn Extension::UNSUPPORTED;\n")
		out.write("}\n")

	def write_source_outro(self, out):
		out.write("""
} // namespace GL
} // namespace Msp
""")

	def write_header(self, fn):
		"""Write the header file to fn; the file is closed when done."""
		with open(fn, "w") as out:
			self.write_header_intro(out)
			self.write_enum_definitions(out)
			self.write_function_pointer_declarations(out)
			out.write("extern Extension %s;\n"%self.target_ext.name)
			self.write_header_outro(out)

	def write_source(self, fn):
		"""Write the source file to fn; the file is closed when done."""
		with open(fn, "w") as out:
			self.write_source_intro(out)
			self.write_function_pointer_definitions(out)
			self.write_init_function(out)
			ext_name = self.target_ext.name
			out.write("\nExtension %s(\"GL_%s\", init_%s);\n"%(ext_name, ext_name, ext_name.lower()))
			self.write_source_outro(out)
584
585
class ExtensionParser:
	"""Reads an extension description file.

	Blank lines and lines starting with '#' are skipped.  Other lines are a
	keyword followed by arguments:
	  extension <name>              the target extension
	  core_version <api> <X.Y>      core version, kept only if <api> is host_api
	  deprecated <api> <X.Y>        deprecated version, kept only if <api> is host_api
	  secondary <name>              an additional extension (may repeat)
	  backport <name>               stored in backport_ext
	  ignore <thing>                a thing name to leave out (may repeat)
	Unknown keywords are ignored.
	"""

	def __init__(self, host_api):
		self.host_api = host_api
		self.target_ext = None
		self.core_version = None
		self.deprecated_version = None
		self.secondary_exts = []
		self.backport_ext = None
		self.ignore_things = []

	def parse(self, fn):
		"""Parse the description file fn into this object's attributes."""
		with open(fn) as f:
			for line in f:
				parts = line.split()
				# Blank lines have no keyword at all; skip them with comments
				if not parts or parts[0].startswith("#"):
					continue

				keyword = parts[0]

				if keyword=="extension":
					self.target_ext = parts[1]
				elif keyword=="core_version":
					if parts[1]==self.host_api:
						self.core_version = Version(*[int(x) for x in parts[2].split('.')])
				elif keyword=="deprecated":
					if parts[1]==self.host_api:
						self.deprecated_version = Version(*[int(x) for x in parts[2].split('.')])
				elif keyword=="secondary":
					self.secondary_exts.append(parts[1])
				elif keyword=="backport":
					self.backport_ext = parts[1]
				elif keyword=="ignore":
					self.ignore_things.append(parts[1])
619
620
def get_extension(api_map, ext_name):
	"""Look up an extension by name: the "gl" API first, then any API.

	Returns None if no API has it.
	"""
	found = api_map["gl"].extensions.get(ext_name)
	if not found:
		for api in api_map.values():
			found = api.extensions.get(ext_name)
			if found:
				break
	return found
631
def collect_things(host_api, target_ext, secondary, ignore):
	"""Return the things to generate code for, each at most once.

	Each thing of target_ext whose name is not in ignore contributes the
	core things it aliases in target_ext's API, or itself if it aliases
	none.  Core things aliased by things of the secondary extensions are
	added as well.  host_api is not used.
	"""
	core_things = target_ext.api.core_things
	things = []

	def add(thing):
		# Two extension things can alias the same core thing; listing it
		# twice would duplicate its definitions in the generated source.
		if thing not in things:
			things.append(thing)

	for name, t in target_ext.things.items():
		if name in ignore:
			continue
		core_aliases = [core_things[a] for a in t.aliases if a in core_things]
		if core_aliases:
			for c in core_aliases:
				add(c)
		else:
			add(t)

	for s in secondary:
		for t in s.things.values():
			for a in t.aliases:
				if a in core_things:
					add(core_things[a])

	return things
653
def main():
	"""Command-line entry point: parse the arguments, read gl.xml and
	gl.fixes.xml from the current directory, and write <out>.h and <out>.cpp.

	Exits with a message instead of a traceback on bad arguments or names.
	"""
	usage = """Usage:
  extgen.py [api] <extfile> [<outfile>]

Reads gl.xml and generates C++ source files to use an OpenGL extension
described in <extfile>.  If <outfile> is absent, the extension's lowercased
name is used.  Anything after the last dot in <outfile> is removed and
replaced with cpp and h."""
	if len(sys.argv)<2:
		print(usage)
		sys.exit(0)

	host_api_name = "gl"

	i = 1
	if sys.argv[i].startswith("gl"):
		host_api_name = sys.argv[i]
		i += 1

	if i>=len(sys.argv):
		# An API name was given but no extension file
		print(usage)
		sys.exit(1)

	ext_fn = sys.argv[i]
	ext_parser = ExtensionParser(host_api_name)
	ext_parser.parse(ext_fn)
	i += 1
	if not ext_parser.target_ext:
		sys.exit("%s: no extension line found"%ext_fn)

	if i<len(sys.argv):
		out_base = os.path.splitext(sys.argv[i])[0]
	else:
		out_base = ext_parser.target_ext.lower()

	xml_parser = GlXmlParser(host_api_name, ext_parser.target_ext)
	xml_parser.parse_file("gl.xml")
	xml_parser.parse_file("gl.fixes.xml")
	xml_parser.finalize()

	host_api = xml_parser.apis.get(host_api_name)
	if not host_api:
		sys.exit("Unknown API %s"%host_api_name)
	target_ext = get_extension(xml_parser.apis, ext_parser.target_ext)
	if not target_ext:
		sys.exit("Extension %s not found"%ext_parser.target_ext)
	secondary_exts = []
	for name in ext_parser.secondary_exts:
		ext = get_extension(xml_parser.apis, name)
		if not ext:
			sys.exit("Secondary extension %s not found"%name)
		secondary_exts.append(ext)
	things = collect_things(host_api, target_ext, secondary_exts, ext_parser.ignore_things)

	generator = SourceGenerator(host_api, target_ext, things)
	# Versions given in the description file override the detected ones
	if ext_parser.core_version:
		generator.core_version = ext_parser.core_version
	if ext_parser.deprecated_version:
		generator.deprecated_version = ext_parser.deprecated_version
	generator.write_header(out_base+".h")
	generator.write_source(out_base+".cpp")

# Run only when executed as a script, not when imported
if __name__=="__main__":
	main()