]> git.tdb.fi Git - libs/gl.git/blob - scripts/extgen.py
Don't crash if an explicitly specified backport extension is not found
[libs/gl.git] / scripts / extgen.py
1 #!/usr/bin/python
2
3 import sys
4 import os
5 import xml.dom
6 import xml.dom.minidom
7 import itertools
8
### Command line processing ###

if len(sys.argv)<2:
    print("""Usage:
  extgen.py [api] <extension> [<core_version>] [<secondary> ...]
  extgen.py [api] <extfile> [<outfile>]

Reads gl.xml and generates files to use <extension>.  Any promoted functions
are exposed with their promoted names.  If <secondary> extensions are given,
any promoted functions from those are pulled in as well.  <core_version> can
be given to override the version where <extension> was promoted to core.

In the second form, the parameters are read from <extfile>.  If <outfile> is
absent, the extension's lowercased name is used.  Anything after the last dot
in <outfile> is removed and replaced with cpp and h.""")
    sys.exit(0)

# Desktop GL is the default; a leading argument starting with "gl" selects
# another API (such as "gles2").
target_api = "gl"

i = 1
if sys.argv[i].startswith("gl"):
    target_api = sys.argv[i]
    i += 1

if i>=len(sys.argv):
    print("No extension or extfile specified")
    sys.exit(1)

target_ext = sys.argv[i]
backport_ext = None
deprecated_version = None
out_base = None
ignore_things = []
if target_ext.endswith(".glext"):
    # Parameters come from a file of "<keyword> <arguments...>" lines.
    # core_version and deprecated lines only apply when they name target_api.
    fn = target_ext
    target_ext = None
    core_version = None
    secondary = []
    with open(fn) as extfile:
        for line in extfile:
            parts = line.split()
            if not parts:
                # Blank lines carry no keyword; skip them instead of
                # failing on parts[0]
                continue
            keyword = parts[0]
            if keyword=="extension":
                target_ext = parts[1]
            elif keyword=="core_version":
                if parts[1]==target_api:
                    core_version = parts[2]
            elif keyword=="deprecated":
                if parts[1]==target_api:
                    deprecated_version = parts[2]
            elif keyword=="secondary":
                secondary.append(parts[1])
            elif keyword=="backport":
                backport_ext = parts[1]
            elif keyword=="ignore":
                ignore_things.append(parts[1])
    if not target_ext:
        print("No extension specified in %s"%fn)
        sys.exit(1)
    if i+1<len(sys.argv):
        out_base = os.path.splitext(sys.argv[i+1])[0]
else:
    # Remaining arguments are secondary extensions, optionally preceded by a
    # core version override (recognized by its leading digit)
    secondary = sys.argv[i+1:]
    core_version = None
    if secondary and secondary[0][:1].isdigit():
        core_version = secondary.pop(0)

# Vendor part of the extension name, e.g. "ARB" in "ARB_foo"
ext_type = target_ext.split('_')[0]

# Versions are kept as lists of ints, e.g. "3.0" -> [3, 0]
if core_version:
    core_version = [int(x) for x in core_version.split('.')]

if deprecated_version:
    deprecated_version = [int(x) for x in deprecated_version.split('.')]

if not out_base:
    out_base = target_ext.lower()
77
### XML file parsing ###

class Thing:
    """Common base for the items declared in the registry: functions and enums.

    Records where the item is available (core API versions, extensions) and
    how it relates to other items with the same meaning.
    """
    FUNCTION = 1
    ENUM = 2

    def __init__(self, name, kind):
        self.name = name
        self.kind = kind
        # api name -> version list, None, or "ext", filled in by the parsers
        self.supported_apis = {}
        # api name -> version in which it was removed from the core profile
        self.deprecated = {}
        # Resolved for target_api during final processing
        self.version = None
        self.deprecated_version = None
        # Extension object that provides this item, if any
        self.extension = None
        # Names of other items this one is an alias of
        self.aliases = []
        # Extension items that name this one as an alias
        self.sources = []
94
class Function(Thing):
    """A GL command taken from a <command> element."""
    def __init__(self, name):
        Thing.__init__(self, name, Thing.FUNCTION)
        # Name of the function pointer typedef, assigned before output
        self.typedef = None
        # Parameter declarations as text, e.g. "GLenum target"
        self.params = []
        self.return_type = "void"
101
class Enum(Thing):
    """A GL enum constant taken from an <enum> element."""
    def __init__(self, name):
        Thing.__init__(self, name, Thing.ENUM)
        # Names like GL_FOO_BIT or GL_FOO_BIT_BAR mark bit flags; value-based
        # aliasing only pairs enums whose flag agrees
        self.bitmask = "_BIT_" in name or name.endswith("_BIT")
        self.value = 0
107
class Extension:
    """An extension from the registry, named without its "GL_" prefix.

    The name is split at the first underscore into the vendor type (such as
    "ARB") and the base name (the rest).
    """
    def __init__(self, name):
        self.name = name
        # APIs from the extension's "supported" attribute
        self.supported_apis = []
        ext_type, sep, base_name = name.partition('_')
        # Without an underscore the whole name is the type, matching how the
        # command line code derives ext_type with split('_')[0]; slicing with
        # find()'s -1 result would drop the last character instead.
        self.ext_type = ext_type
        self.base_name = base_name if sep else name
115
# Global registries filled by the parse_* functions, both keyed by name.
# Extension names are stored without the "GL_" prefix.
things = {}
extensions = {}
118
def get_nested_elements(elem, path):
    """Return the elements reached from elem by a slash-separated tag path.

    Each path component matches direct element children only, so
    "commands/command" yields <command> children of <commands> children.
    """
    head, sep, tail = path.partition('/')
    matches = [c for c in elem.childNodes
               if c.nodeType==xml.dom.Node.ELEMENT_NODE and c.tagName==head]
    if not sep:
        return matches
    found = []
    for m in matches:
        found.extend(get_nested_elements(m, tail))
    return found
130
def get_first_child(elem, tag):
    """Return the first direct child element of elem named tag, or None."""
    return next((c for c in elem.childNodes
                 if c.nodeType==xml.dom.Node.ELEMENT_NODE and c.tagName==tag), None)
136
def get_text_contents(node):
    """Concatenate all text and CDATA found beneath node, in document order."""
    text_types = (xml.dom.Node.TEXT_NODE, xml.dom.Node.CDATA_SECTION_NODE)
    return "".join(c.data if c.nodeType in text_types else get_text_contents(c)
                   for c in node.childNodes)
145
def parse_command(cmd):
    """Create or update the Function described by a <command> element.

    Records the function's aliases, its return type and its parameter
    declarations (as text).
    """
    proto = get_first_child(cmd, "proto")
    name = get_text_contents(get_first_child(proto, "name"))
    func = things.get(name)
    if not func:
        func = Function(name)
        things[name] = func

    aliases = get_nested_elements(cmd, "alias")
    func.aliases = [a.getAttribute("name") for a in aliases]

    # The return type is everything in <proto> except the <name>.  A proto
    # like "const <ptype>GLubyte</ptype> *<name>glGetString</name>" must
    # give "const GLubyte *"; the <ptype> text alone would drop the
    # qualifier and the pointer declarator.
    pieces = []
    for c in proto.childNodes:
        if c.nodeType in (xml.dom.Node.TEXT_NODE, xml.dom.Node.CDATA_SECTION_NODE):
            pieces.append(c.data)
        elif c.nodeType==xml.dom.Node.ELEMENT_NODE and c.tagName!="name":
            pieces.append(get_text_contents(c))
    # Collapse runs of whitespace so the result reads like a C type
    return_type = " ".join("".join(pieces).split())
    if return_type:
        func.return_type = return_type

    params = get_nested_elements(cmd, "param")
    func.params = [get_text_contents(p) for p in params]
168
def parse_enum(en):
    """Create or update the Enum named by an <enum> element and set its value.

    The value attribute is parsed as hexadecimal (int() accepts an optional
    "0x" prefix with base 16).
    """
    name = en.getAttribute("name")
    try:
        enum = things[name]
    except KeyError:
        enum = things[name] = Enum(name)
    enum.value = int(en.getAttribute("value"), 16)
177
def parse_feature(feat):
    """Apply a <feature> (core API version) element to the known things.

    Each command and enum it requires gets the feature's version recorded for
    the feature's API, unless an earlier feature already recorded one.  For
    the target API (or a feature without an api attribute), things in
    <remove> elements are marked deprecated when the removal is for the core
    profile, and dropped from the registry otherwise.
    """
    api = feat.getAttribute("api")
    number = feat.getAttribute("number")
    version = map(int, number.split('.')) if number else None

    for req in get_nested_elements(feat, "require"):
        required = itertools.chain(get_nested_elements(req, "command"), get_nested_elements(req, "enum"))
        for elem in required:
            thing = things.get(elem.getAttribute("name"))
            if thing:
                thing.supported_apis.setdefault(api, version)

    if api and api!=target_api:
        return

    for rem in get_nested_elements(feat, "remove"):
        profile = rem.getAttribute("profile")
        removed = itertools.chain(get_nested_elements(rem, "command"), get_nested_elements(rem, "enum"))
        for elem in removed:
            thing_name = elem.getAttribute("name")
            if thing_name not in things:
                continue
            if profile=="core":
                things[thing_name].deprecated.setdefault(api, version)
            else:
                del things[thing_name]
210
def parse_extension(ext):
    """Register an <extension> element and attach the things it requires.

    Extensions the target API does not support are skipped, except for the
    target extension itself.  A thing already attached to another extension
    is only taken over if the current extension is the target one, or if
    the previous owner is neither ARB nor the target extension, and is not
    an EXT being displaced by a non-ARB extension.
    """
    ext_name = ext.getAttribute("name")
    if ext_name.startswith("GL_"):
        ext_name = ext_name[3:]

    supported = ext.getAttribute("supported").split('|')
    if ext_name!=target_ext and target_api not in supported:
        return

    extension = extensions.get(ext_name)
    if not extension:
        extension = extensions[ext_name] = Extension(ext_name)
    extension.supported_apis = supported

    for req in get_nested_elements(ext, "require"):
        # A require block may be restricted to a single API
        req_api = req.getAttribute("api")
        req_supported = [req_api] if req_api else extension.supported_apis

        required = itertools.chain(get_nested_elements(req, "command"), get_nested_elements(req, "enum"))
        for elem in required:
            thing_name = elem.getAttribute("name")
            if thing_name in ignore_things:
                continue

            thing = things.get(thing_name)
            if not thing:
                continue

            owner = thing.extension
            if owner and extension.name!=target_ext:
                if owner.ext_type=="ARB" or owner.name==target_ext:
                    continue
                if owner.ext_type=="EXT" and extension.ext_type!="ARB":
                    continue

            thing.extension = extension
            for a in req_supported:
                thing.supported_apis.setdefault(a, "ext")
253
def parse_file(fn):
    """Parse a registry XML file and merge its contents into the globals.

    Commands and enums are created first so that features and extensions
    can refer to them.  Availability is recorded with setdefault, so
    processing features before extensions lets a core version take
    precedence over "ext".
    """
    root = xml.dom.minidom.parse(fn).documentElement

    for cmd in get_nested_elements(root, "commands/command"):
        parse_command(cmd)
    for en in get_nested_elements(root, "enums/enum"):
        parse_enum(en)
    for feat in get_nested_elements(root, "feature"):
        parse_feature(feat)
    # Named to avoid shadowing the global extensions registry
    for ext_elem in get_nested_elements(root, "extensions/extension"):
        parse_extension(ext_elem)
273
# gl.fixes.xml is parsed second, so its definitions are merged over gl.xml's
for registry_file in ("gl.xml", "gl.fixes.xml"):
    parse_file(registry_file)

### Additional processing ###

if target_ext not in extensions:
    print("Extension %s not found"%target_ext)
    sys.exit(1)
target_ext = extensions[target_ext]

# Find aliases for enums: enums available in some core version, indexed by
# value.  The per-value candidate lists are filled in lazily.
enums = [thing for thing in things.values() if thing.kind==Thing.ENUM]
core_enums = [e for e in enums if any(v!="ext" for v in e.supported_apis.values())]
core_enums_by_value = dict.fromkeys(e.value for e in core_enums)
289
def get_key_api(things):
    """Pick an API supported by the target extension and by every item in things.

    Falls back to target_api when there is no common API.
    """
    common = set(target_ext.supported_apis)
    for item in things:
        common.intersection_update(item.supported_apis)
    return common.pop() if common else target_api
298
# An extension-only enum sharing its value with core enums becomes an alias
# of the first such core enum with the same bitmask-ness.  Candidates are
# sorted by their availability in an API common to them and the target
# extension, with "ext" entries last.
for e in enums:
    if e.value not in core_enums_by_value:
        continue
    if not all(v=="ext" for v in e.supported_apis.values()):
        continue
    candidates = core_enums_by_value[e.value]
    if candidates is None:
        same_value = [ce for ce in core_enums if ce.value==e.value]
        key_api = get_key_api(same_value)
        candidates = sorted(same_value, key=(lambda x: x.supported_apis.get(key_api, "ext")))
        core_enums_by_value[e.value] = candidates
    for ce in candidates:
        if ce.bitmask==e.bitmask:
            e.aliases.append(ce.name)
            break

# Create references from core things to their extension counterparts.
# Counterparts usable with the target API go to the front of the list.
for t in things.values():
    if not t.extension:
        continue
    for a in t.aliases:
        alias = things.get(a)
        if not alias:
            continue
        if target_api in t.supported_apis:
            alias.sources.insert(0, t)
        else:
            alias.sources.append(t)
320
# Find the things we want to include in this extension
def is_relevant(t):
    """Tell whether thing t belongs in the generated files.

    True for things of the target extension that have no aliases (i.e. were
    not promoted), and for things that the target extension or a secondary
    extension was a source of.
    """
    if t.extension and not t.aliases and t.extension==target_ext:
        return True
    return any(s.extension==target_ext or s.extension.name in secondary
               for s in t.sources)
333
# Relevant functions ordered by name, relevant enums ordered by value
funcs = sorted((thing for thing in things.values() if thing.kind==Thing.FUNCTION and is_relevant(thing)),
               key=(lambda f: f.name))
enums = sorted((e for e in enums if is_relevant(e)), key=(lambda e: e.value))
338
# Some final preparations for creating the files
core_version_candidates = {}
backport_ext_candidates = []
deprecated_versions = []
all_deprecated = True
for t in itertools.chain(funcs, enums):
    # Count how many things entered the target API in each core version
    if target_api in t.supported_apis and t.supported_apis[target_api]!="ext":
        t.version = t.supported_apis[target_api]
        if t.version:
            ver_tuple = tuple(t.version)
            core_version_candidates[ver_tuple] = core_version_candidates.get(ver_tuple, 0)+1

    if target_api in t.deprecated:
        t.deprecated_version = t.deprecated[target_api]
        deprecated_versions.append(t.deprecated_version)
    else:
        all_deprecated = False

    # Things in backport extensions don't acquire an extension suffix
    if t.extension and not t.name.endswith(ext_type) and target_api in t.supported_apis:
        if t.extension not in backport_ext_candidates:
            backport_ext_candidates.append(t.extension)

# The set only counts as deprecated if every thing in it was removed from the
# target API's core profile, and then from the earliest such removal.  An
# empty set is never deprecated, so no placeholder version can reach the
# generated code.
min_deprecated_version = None
if all_deprecated and deprecated_versions:
    min_deprecated_version = min(deprecated_versions)

if not core_version and core_version_candidates:
    # Use the version most things were promoted in; warn on a close call
    ranked = sorted(((count, ver) for ver, count in core_version_candidates.items()), reverse=True)
    if len(ranked)>1 and ranked[1][0]+1>=ranked[0][0]:
        ver0 = ranked[0][1]
        ver1 = ranked[1][1]
        print("Warning: multiple likely core version candidates: %d.%d %d.%d"%(ver0[0], ver0[1], ver1[0], ver1[1]))
    core_version = ranked[0][1]

if not deprecated_version:
    deprecated_version = min_deprecated_version
373
# Determine the backport extension: either given explicitly ("none" disables
# it) or guessed from the candidates collected above.
if backport_ext:
    if backport_ext=="none":
        backport_ext = None
    else:
        bpe_name = backport_ext
        backport_ext = extensions.get(bpe_name)

        if not backport_ext:
            # Report a missing extension as such rather than as a mismatch
            print("Warning: explicitly specified backport extension %s not found"%bpe_name)
        elif backport_ext not in backport_ext_candidates:
            print("Warning: explicitly specified backport extension %s does not look like a backport extension"%bpe_name)
elif backport_ext_candidates:
    if len(backport_ext_candidates)>1:
        print("Warning: multiple backport extension candidates: %s"%(" ".join(e.name for e in backport_ext_candidates)))

    # Prefer a candidate sharing the target extension's base name
    for e in backport_ext_candidates:
        if e.base_name==target_ext.base_name:
            backport_ext = e

    if not backport_ext and len(backport_ext_candidates)==1:
        print("Warning: potential backport extension has mismatched name: %s"%backport_ext_candidates[0].name)
393
for f in funcs:
    f.typedef = "FPtr_" + f.name

# Extension whose entry points are loaded in the non-core path of the
# generated init function: the target extension if the target API supports
# it, otherwise the source extension (usable with the target API) that the
# most things come from.
source_ext = None
if target_api in target_ext.supported_apis:
    source_ext = target_ext
else:
    candidates = {}
    for t in itertools.chain(funcs, enums):
        for s in t.sources:
            if target_api in s.supported_apis:
                src_name = s.extension.name
                candidates[src_name] = candidates.get(src_name, 0)+1
    if candidates:
        source_ext = extensions[max(candidates, key=candidates.get)]
409
# Warn when the target API lacks some or all of the selected things
if funcs or enums:
    missing = [t for t in itertools.chain(funcs, enums) if target_api not in t.supported_apis]

    if len(missing)==len(funcs)+len(enums):
        print("Warning: %s is not supported by the target API"%target_ext.name)
    elif missing:
        print("Warning: %s is only partially supported by the target API"%target_ext.name)
        # List the missing names, wrapping lines at 78 columns
        label = "Warning: Unsupported tokens: "
        listing = ""
        for t in missing:
            if listing and len(label)+len(listing)+2+len(t.name)>78:
                print(label+listing)
                label = " "*len(label)
                listing = ""
            if listing:
                listing += ", "
            listing += t.name
        if listing:
            print(label+listing)
436
### Output ###

# The with statement guarantees the header is flushed and closed
with open(out_base+".h", "w") as out:
    out.write("#ifndef MSP_GL_%s_\n"%target_ext.name.upper())
    out.write("#define MSP_GL_%s_\n"%target_ext.name.upper())

    out.write("""
#include <msp/gl/extension.h>
#include <msp/gl/gl.h>

namespace Msp {
namespace GL {

""")

    # Function pointer types for every entry point
    if funcs:
        for f in funcs:
            out.write("typedef %s (APIENTRY *%s)(%s);\n"%(f.return_type, f.typedef, ", ".join(f.params)))
        out.write("\n")

    # Enum definitions, grouped under an #ifndef of the core version or
    # extension macro that provides them (presumably so definitions from the
    # system GL headers are not duplicated)
    if enums:
        api_prefix = "GL"
        if target_api=="gles2":
            api_prefix = "GL_ES"

        enums_by_category = {}
        for e in enums:
            cat = None
            if e.version:
                cat = api_prefix+"_VERSION_"+"_".join(map(str, e.version))
            elif e.extension:
                cat = "GL_"+e.extension.name
            enums_by_category.setdefault(cat, []).append(e)

        for cat in sorted(enums_by_category.keys()):
            if cat:
                out.write("#ifndef %s\n"%cat)
            for e in enums_by_category[cat]:
                if e.value<0:
                    # "0x%04X" would render e.g. -2 as the invalid "0x-002"
                    out.write("#define %s %d\n"%(e.name, e.value))
                else:
                    out.write("#define %s 0x%04X\n"%(e.name, e.value))
            if cat:
                out.write("#endif\n")
            out.write("\n")

    for f in funcs:
        out.write("extern %s %s;\n"%(f.typedef, f.name))

    out.write("extern Extension %s;\n"%target_ext.name)

    out.write("""
} // namespace GL
} // namespace Msp

#endif
""")
492
with open(out_base+".cpp", "w") as out:
    # Include the header written as out_base+".h".  out_base may differ from
    # the extension name (explicit <outfile>) and may contain a directory, so
    # use its file name rather than the extension name.
    out.write("#include \"%s.h\"\n"%os.path.basename(out_base))

    if funcs:
        out.write("""
#ifdef __APPLE__
#define GET_PROC_ADDRESS(x) &::x
#else
#define GET_PROC_ADDRESS(x) get_proc_address(#x)
#endif

#ifdef _WIN32
#define GET_PROC_ADDRESS_1_1(x) &::x
#else
#define GET_PROC_ADDRESS_1_1(x) GET_PROC_ADDRESS(x)
#endif
""")
    out.write("""
namespace Msp {
namespace GL {

""")

    for f in funcs:
        out.write("%s %s = 0;\n"%(f.typedef, f.name))

    out.write("\nExtension::SupportLevel init_%s()\n{\n"%target_ext.name.lower())
    if core_version:
        # Core path: load the promoted entry points when the GL version (or
        # the backport extension) is available
        out.write("\tif(is_disabled(\"GL_%s\"))\n\t\treturn Extension::UNSUPPORTED;\n"%target_ext.name)
        out.write("#if !defined(__APPLE__) || defined(GL_VERSION_%d_%d)\n"%tuple(core_version))
        out.write("\tif(")
        if backport_ext:
            out.write("is_supported(\"GL_%s\") || "%backport_ext.name)
        out.write("is_supported(Version(%d, %d)"%tuple(core_version))
        if deprecated_version:
            out.write(", Version(%d, %d)"%tuple(deprecated_version))
        out.write("))\n\t{\n")
        for f in funcs:
            if target_api in f.supported_apis:
                gpa_suffix = ""
                if f.version is not None and f.version<=[1, 1]:
                    gpa_suffix = "_1_1"
                out.write("\t\t%s = reinterpret_cast<%s>(GET_PROC_ADDRESS%s(%s));\n"%(f.name, f.typedef, gpa_suffix, f.name))
        out.write("\t\treturn Extension::CORE;\n")
        out.write("\t}\n")
        out.write("#endif\n")
    if source_ext and source_ext!=backport_ext:
        # Extension path: load the suffixed entry points of source_ext
        out.write("#if !defined(__APPLE__) || defined(GL_%s)\n"%target_ext.name)
        out.write("\tif(is_supported(\"GL_%s\"))\n\t{\n"%(source_ext.name))
        for f in funcs:
            if f.sources:
                src = None
                for s in f.sources:
                    if s.name.endswith(source_ext.ext_type):
                        src = s
                        break
                if not src:
                    src = f.sources[0]
            else:
                src = f

            if target_api in src.supported_apis:
                if not src.name.endswith(source_ext.ext_type):
                    print("Warning: %s does not match extension type %s"%(src.name, source_ext.ext_type))
                out.write("\t\t%s = reinterpret_cast<%s>(GET_PROC_ADDRESS(%s));\n"%(f.name, f.typedef, src.name))
        out.write("\t\treturn Extension::EXTENSION;\n")
        out.write("\t}\n")
        out.write("#endif\n")
    out.write("\treturn Extension::UNSUPPORTED;\n")
    out.write("}\n")

    out.write("\nExtension %s(\"GL_%s\", init_%s);\n"%(target_ext.name, target_ext.name, target_ext.name.lower()))

    out.write("""
} // namespace GL
} // namespace Msp
""")