]> git.tdb.fi Git - libs/gl.git/blob - scripts/extgen.py
Check the flat qualifier from the correct member
[libs/gl.git] / scripts / extgen.py
1 #!/usr/bin/python3
2
3 import sys
4 import os
5 import xml.dom
6 import xml.dom.minidom
7 import itertools
8 import re
9
class Version:
	"""A major.minor version number.

	None stands for "no version" in comparisons: every Version is greater
	than None, never less than it and never equal to it.  Comparing with any
	other non-Version object returns NotImplemented, so Python applies its
	default handling instead of failing with AttributeError.
	"""

	def __init__(self, *args):
		# Either no arguments (0.0) or exactly (major, minor).
		if len(args)==0:
			self.major = 0
			self.minor = 0
		elif len(args)==2:
			self.major = args[0]
			self.minor = args[1]
		else:
			raise TypeError("__init__() takes zero or two arguments ({} given)".format(len(args)))

	def __str__(self):
		return "{}.{}".format(self.major, self.minor)

	def __repr__(self):
		# The repr is written verbatim into generated code (is_supported(...)),
		# so this format must stay stable.
		return "Version({}, {})".format(self.major, self.minor)

	def as_define(self):
		"""Return the macro suffix for this version, e.g. VERSION_3_0."""
		return "VERSION_{}_{}".format(self.major, self.minor)

	def _key(self):
		# Tuple ordering compares major first, then minor.
		return (self.major, self.minor)

	def __lt__(self, other):
		if other is None:
			return False
		if not isinstance(other, Version):
			return NotImplemented
		return self._key()<other._key()

	def __gt__(self, other):
		if other is None:
			return True
		if not isinstance(other, Version):
			return NotImplemented
		return self._key()>other._key()

	def __le__(self, other):
		result = self.__gt__(other)
		if result is NotImplemented:
			return result
		return not result

	def __ge__(self, other):
		result = self.__lt__(other)
		if result is NotImplemented:
			return result
		return not result

	def __eq__(self, other):
		if other is None:
			return False
		if not isinstance(other, Version):
			return NotImplemented
		return self._key()==other._key()

	def __hash__(self):
		# Defining __eq__ alone would make instances unhashable.
		return hash(self._key())
57
58
class Thing:
	"""A named registry entity (function or enum) plus its per-API support data."""

	FUNCTION = 1
	ENUM = 2

	class ApiSupport:
		"""How a Thing is made available in one API."""

		def __init__(self):
			# Version in which the thing became core, or None.
			self.core_version = None
			# Version of the <remove> that dropped it, or None.
			self.deprecated_version = None
			# Extension objects that provide the thing.
			self.extensions = []
			# Extension things that alias this one.
			self.sources = []

	def __init__(self, name, kind):
		self.name = name
		self.kind = kind
		self.aliases = []
		self.api_support = {}

	def get_or_create_api_support(self, api):
		"""Return the ApiSupport record for *api*, adding an empty one if absent."""
		existing = self.api_support.get(api)
		if existing:
			return existing
		created = Thing.ApiSupport()
		self.api_support[api] = created
		return created
82
83
class Function(Thing):
	"""A registry command (function) with its return type and parameter texts."""

	def __init__(self, name):
		super().__init__(name, Thing.FUNCTION)
		# Defaults until a <proto>/<param> element supplies the real values.
		self.return_type = "void"
		self.params = []
89
90
# Names containing _BIT, optionally followed by digits, then "_" or the end of
# the name (e.g. GL_COLOR_BUFFER_BIT) are treated as bit flags.
r_bitmask = re.compile(r"_BIT[0-9]*(_|$)")

class Enum(Thing):
	"""A registry enum: a named integer constant."""

	def __init__(self, name):
		super().__init__(name, Thing.ENUM)
		self.value = 0
		# Whether the enum is a bit flag is decided from its name alone.
		self.bitmask = r_bitmask.search(self.name) is not None
98
99
class Extension:
	"""One extension (e.g. ARB_foo) as exposed by a particular API.

	The text before the first underscore is the extension type (ARB, EXT,
	...), and the rest is the base name.
	"""

	# Ranking among extensions that provide the same thing; higher wins.
	# Types not listed get 0.
	_PREFERENCES = {"EXT": 1, "ARB": 2, "OES": 2}

	def __init__(self, name, api):
		self.name = name
		ext_type, sep, base_name = name.partition('_')
		if not sep:
			# No type prefix at all: keep the whole name as the base name
			# instead of slicing with find()'s -1 result.
			ext_type, base_name = "", name
		self.ext_type = ext_type
		self.base_name = base_name
		self.api = api
		self.things = {}
		self.preference = Extension._PREFERENCES.get(self.ext_type, 0)
		# Set later for ARB extensions that expose core names directly.
		self.backport = False
114
115
class Api:
	"""Registry data collected for one API (e.g. gles2)."""

	def __init__(self, name):
		self.name = name
		self.latest_version = None
		# Thing name -> Thing, for things that are core in some version.
		self.core_things = {}
		# Extension name (without the GL_ prefix) -> Extension.
		self.extensions = {}
122
123
def get_nested_elements(elem, path):
	"""Return the elements below *elem* selected by a slash-separated tag path.

	Each path component selects the direct element children with that tag
	name.  Results are ordered by parent, then by position among siblings.
	"""
	level = [elem]
	for tag in path.split('/'):
		level = [child for parent in level for child in parent.childNodes
			if child.nodeType==xml.dom.Node.ELEMENT_NODE and child.tagName==tag]
	return level
135
def get_first_child(elem, tag):
	"""Return the first direct element child of *elem* named *tag*, or None."""
	matches = (c for c in elem.childNodes
		if c.nodeType==xml.dom.Node.ELEMENT_NODE and c.tagName==tag)
	return next(matches, None)
141
def get_text_contents(node):
	"""Concatenate all text and CDATA found anywhere below *node*, in document order."""
	text_types = (xml.dom.Node.TEXT_NODE, xml.dom.Node.CDATA_SECTION_NODE)
	return "".join(
		child.data if child.nodeType in text_types else get_text_contents(child)
		for child in node.childNodes)
150
def get_or_create(map, name, type, *args):
	"""Return map[name], first creating it as type(name, *args) if absent.

	The parameter names shadow the builtins map and type; they are kept
	unchanged for compatibility with existing callers.
	"""
	obj = map.get(name)
	# Test for absence explicitly, so a stored object that happens to be
	# falsy is not replaced by a new one.
	if obj is None:
		obj = type(name, *args)
		map[name] = obj
	return obj
157
158
class GlXmlParser:
	"""Builds Api, Extension and Thing objects from a registry XML file.

	Call parse_file(), then finalize() to resolve backports, enum aliases,
	alias sources and extension ordering.
	"""

	def __init__(self):
		self.apis = {}    # API name -> Api
		self.things = {}  # command/enum name -> Function or Enum

	def parse_command(self, cmd):
		"""Record a <command> element as a Function (aliases, return type, params)."""
		proto = get_first_child(cmd, "proto")
		name = get_text_contents(get_first_child(proto, "name"))
		func = get_or_create(self.things, name, Function)

		aliases = get_nested_elements(cmd, "alias")
		func.aliases = [a.getAttribute("name") for a in aliases]

		# The return type is everything in <proto> before <name>, so that
		# qualifiers around a <ptype> survive, e.g.
		# "const <ptype>GLubyte</ptype> *" gives "const GLubyte *".
		parts = []
		for c in proto.childNodes:
			if c.nodeType==xml.dom.Node.ELEMENT_NODE:
				if c.tagName=="name":
					break
				parts.append(get_text_contents(c))
			elif c.nodeType in (xml.dom.Node.TEXT_NODE, xml.dom.Node.CDATA_SECTION_NODE):
				parts.append(c.data)
		return_type = "".join(parts).strip()
		if return_type:
			func.return_type = return_type

		# Use a real list: a lazy map() iterator could only be consumed once.
		params = get_nested_elements(cmd, "param")
		func.params = [get_text_contents(p) for p in params]

	def parse_enum(self, en):
		"""Record an <enum> element as an Enum with its value and optional alias."""
		name = en.getAttribute("name")
		enum = get_or_create(self.things, name, Enum)

		# NOTE(review): every value is parsed as hexadecimal (a 0x prefix is
		# accepted); this assumes the registry writes all values in hex.
		enum.value = int(en.getAttribute("value"), 16)

		alias = en.getAttribute("alias")
		if alias:
			enum.aliases.append(alias)

	def parse_feature(self, feat):
		"""Record a <feature>: required things become core and removed things
		deprecated in the feature's version, for the feature's API."""
		api_name = feat.getAttribute("api")
		api = get_or_create(self.apis, api_name, Api)

		version = feat.getAttribute("number")
		if version:
			version = Version(*map(int, version.split('.')))
		else:
			version = None

		for req in get_nested_elements(feat, "require"):
			commands = get_nested_elements(req, "command")
			enums = get_nested_elements(req, "enum")
			for t in itertools.chain(commands, enums):
				thing = self.things.get(t.getAttribute("name"))
				if thing:
					supp = thing.get_or_create_api_support(api.name)
					# Keep the earliest version that made the thing core.
					if not supp.core_version or version<supp.core_version:
						supp.core_version = version
					api.core_things[thing.name] = thing

		for rem in get_nested_elements(feat, "remove"):
			commands = get_nested_elements(rem, "command")
			enums = get_nested_elements(rem, "enum")
			for t in itertools.chain(commands, enums):
				thing = self.things.get(t.getAttribute("name"))
				if thing:
					supp = thing.get_or_create_api_support(api.name)
					supp.deprecated_version = version

	def parse_extension(self, ext):
		"""Record an <extension> element for every known API in its supported list."""
		# Required things grouped by the <require> api attribute; "" (no
		# attribute) applies to every supported API.
		ext_things_by_api = {}
		for req in get_nested_elements(ext, "require"):
			api = req.getAttribute("api")
			ext_things = ext_things_by_api.setdefault(api, [])

			commands = get_nested_elements(req, "command")
			enums = get_nested_elements(req, "enum")
			for t in itertools.chain(commands, enums):
				thing = self.things.get(t.getAttribute("name"))
				if thing:
					ext_things.append(thing)

		ext_name = ext.getAttribute("name")
		if ext_name.startswith("GL_"):
			ext_name = ext_name[3:]

		common_things = ext_things_by_api.get("", [])
		supported = ext.getAttribute("supported").split('|')
		for s in supported:
			api = self.apis.get(s)
			if not api:
				continue

			# A separate name keeps the XML element "ext" from being shadowed.
			extension = get_or_create(api.extensions, ext_name, Extension, api)
			api_things = ext_things_by_api.get(s, [])
			for t in itertools.chain(common_things, api_things):
				extension.things[t.name] = t
				t.get_or_create_api_support(api.name).extensions.append(extension)

	def parse_file(self, fn):
		"""Parse commands, enums, features and extensions from file *fn*, in that order."""
		doc = xml.dom.minidom.parse(fn)
		root = doc.documentElement

		for cmd in get_nested_elements(root, "commands/command"):
			self.parse_command(cmd)

		for en in get_nested_elements(root, "enums/enum"):
			self.parse_enum(en)

		for feat in get_nested_elements(root, "feature"):
			self.parse_feature(feat)

		for ext in get_nested_elements(root, "extensions/extension"):
			self.parse_extension(ext)

	def check_backport_extensions(self, api):
		"""Mark ARB extensions as backports when none of their things has a
		name ending in ARB, i.e. they expose the core names directly."""
		for e in api.extensions.values():
			if e.ext_type!="ARB":
				continue

			e.backport = True
			for t in e.things.values():
				if t.name.endswith(e.ext_type):
					e.backport = False
					break

	def resolve_enum_aliases(self, api):
		"""Alias non-core extension enums to the core enum with the same name
		(minus the _<type> suffix) and the same value."""
		for e in api.extensions.values():
			enum_suffix = "_"+e.ext_type
			for n in (t for t in e.things.values() if t.kind==Thing.ENUM):
				if n.api_support[api.name].core_version:
					continue

				name = n.name
				if name.endswith(enum_suffix):
					name = name[:-len(enum_suffix)]
				ce = api.core_things.get(name)
				if ce and ce.value==n.value and ce.name not in n.aliases:
					n.aliases.append(ce.name)

	def resolve_sources(self, api):
		"""Register each extension thing as a source of the core things it aliases."""
		for e in api.extensions.values():
			for t in e.things.values():
				for a in t.aliases:
					# There are a few cases where a vendor function is aliased to
					# an EXT or ARB function but those are rare and not relevant for
					# our use
					alias = api.core_things.get(a)
					if alias:
						sources = alias.api_support[api.name].sources
						if t not in sources:
							sources.append(t)

	def sort_extensions(self):
		"""Order every thing's extensions by descending preference (stable)."""
		for t in self.things.values():
			for s in t.api_support.values():
				s.extensions.sort(key=(lambda e: e.preference), reverse=True)

	def finalize(self):
		"""Run all cross-reference passes; call once after parsing."""
		for a in self.apis.values():
			self.check_backport_extensions(a)
			self.resolve_enum_aliases(a)
			self.resolve_sources(a)
		self.sort_extensions()
330
331
def detect_core_version(host_api, things, debug=None):
	"""Return the host_api core version that contains all of *things*.

	The result is the highest core_version among the things (never below
	1.0), or None if at least one thing is not core in host_api.  In that
	case those things are listed when *debug* is set.

	A warning is printed when more things became core in earlier versions
	than in the returned one, or when some things are not core at all but
	they number fewer than half of the core ones.
	"""
	highest = Version(1, 0)
	count_at_highest = 0
	count_below = 0
	not_core = []
	for thing in things:
		supp = thing.api_support.get(host_api.name)
		if not (supp and supp.core_version):
			not_core.append(thing)
			continue

		version = supp.core_version
		if version>highest:
			# Things at the previous maximum now count as lower.
			count_below += count_at_highest
			highest = version
			count_at_highest = 1
		elif version==highest:
			count_at_highest += 1
		else:
			count_below += 1

	core_count = count_below+count_at_highest
	if count_below>count_at_highest or (not_core and len(not_core)*2<core_count):
		print("Warning: Inconsistent core version {}".format(highest))

	if not not_core:
		return highest

	if debug:
		print("---")
		print("{} things missing from core:".format(len(not_core)))
		for thing in not_core:
			print("  "+thing.name)
	return None
363
def detect_deprecated_version(host_api, things, debug):
	"""Return the earliest version in which any of *things* was removed from
	host_api, or None if none of them were.

	When fewer than half of the things were removed, a warning is printed,
	followed by the removed things if *debug* is set.
	"""
	removed = []
	for t in things:
		supp = t.api_support.get(host_api.name)
		when = supp.deprecated_version if supp else None
		if when:
			removed.append((t, when))

	min_version = min((when for _, when in removed), default=None)

	if min_version and len(removed)*2<len(things):
		print("Warning: Inconsistent deprecation version {}".format(min_version))
		if debug:
			print("---")
			print("{} things are deprecated:".format(len(removed)))
			for t, _ in removed:
				print("  "+t.name)

	return min_version
385
def detect_backport_extension(host_api, things):
	"""Return a backport extension that provides every one of *things*.

	Candidates are the backport extensions listed for things that are core
	in host_api.  The first candidate containing all things is returned.
	Otherwise the result is None; if the best partial candidate still covers
	at least half of the things, a warning naming it is printed.
	"""
	candidates = []
	for t in things:
		supp = t.api_support.get(host_api.name)
		if not (supp and supp.core_version):
			continue
		for ext in supp.extensions:
			if ext.backport and ext not in candidates:
				candidates.append(ext)

	total = len(things)
	best_ext, best_count = None, 0
	for ext in candidates:
		covered = sum(1 for t in things if t.name in ext.things)
		if covered==total:
			return ext
		if covered>best_count:
			best_ext, best_count = ext, covered

	if total and best_count*2>=total:
		print("Warning: Inconsistent backport extension {}".format(best_ext.name))
	return None
408
def collect_extensions(thing, api, exts):
	"""Append to *exts* the extensions providing *thing* in *api*, then recurse
	into the thing's sources.

	Backport and MSP extensions are skipped, and extensions already in *exts*
	are not added again.  Things without support in *api* contribute nothing.
	"""
	supp = thing.api_support.get(api)
	if not supp:
		return

	usable = (e for e in supp.extensions if not e.backport and e.ext_type!="MSP")
	for ext in usable:
		if ext not in exts:
			exts.append(ext)

	for source in supp.sources:
		collect_extensions(source, api, exts)
420
def detect_source_extension(host_api, things, debug=False):
	"""Find extensions (optionally on top of a core version) that together
	provide all of *things* in host_api.

	Greedily picks the extension that provides the most still-missing things
	(ties go to the higher Extension.preference) until nothing is missing or
	no candidate extensions remain.  Along the way it tracks the lowest core
	version (base_version) that would cover everything still missing, and
	how many of the picked extensions were needed to reach it.

	Returns (None, extensions) if the extensions alone cover everything,
	(base_version, extensions) if the remainder is core in base_version
	(extensions picked after base_version was last lowered are dropped), or
	(None, None) if the things cannot be provided.
	"""
	# Candidate extension -> things it provides (directly or via alias sources).
	things_by_ext = {}
	for t in things:
		exts = []
		collect_extensions(t, host_api.name, exts)
		for e in exts:
			things_by_ext.setdefault(e, []).append(t)

	if debug:
		print("---")
		print("Looking for {} things in {} extensions".format(len(things), len(things_by_ext)))

	extensions = []
	# Number of leading entries of `extensions` to keep together with base_version.
	keep_exts = 0
	base_version = None
	recheck_base_version = True
	missing = set(things)
	while 1:
		if recheck_base_version:
			# Highest core version among the missing things, or None as soon
			# as one of them is not core at all (None then sticks).
			max_version = Version(1, 0)
			for t in missing:
				supp = t.api_support.get(host_api.name)
				if supp and supp.core_version and max_version:
					max_version = max(max_version, supp.core_version)
				else:
					max_version = None

			if max_version:
				# A lower base version is an improvement; remember how many
				# extensions it took to get there.
				if not base_version or max_version<base_version:
					base_version = max_version
					keep_exts = len(extensions)
			elif not base_version:
				# No usable base version yet: everything picked so far is needed.
				keep_exts = len(extensions)

			recheck_base_version = False

		if not missing or not things_by_ext:
			break

		# Pick the extension covering the most missing things; on a tie,
		# prefer the higher Extension.preference.
		largest_ext = None
		largest_count = 0
		for e, t in things_by_ext.items():
			count = len(t)
			if count>largest_count:
				largest_ext = e
				largest_count = count
			elif count==largest_count and e.preference>largest_ext.preference:
				largest_ext = e

		if debug:
			print("Found {} things in {}".format(largest_count, largest_ext.name))

		extensions.append(largest_ext)
		for t in things_by_ext[largest_ext]:
			missing.remove(t)

			# Covering a thing at the current base version may allow a lower one.
			supp = t.api_support.get(host_api.name)
			if supp and supp.core_version==base_version:
				recheck_base_version = True

		# Drop the chosen extension and prune already-covered things from the
		# remaining candidates; candidates left with nothing are removed.
		del things_by_ext[largest_ext]
		for e in list(things_by_ext.keys()):
			unseen = [t for t in things_by_ext[e] if t in missing]
			if unseen:
				things_by_ext[e] = unseen
			else:
				del things_by_ext[e]

	if not missing:
		return None, extensions
	elif base_version:
		if debug:
			print("Found remaining things in version {}".format(base_version))
			if keep_exts<len(extensions):
				print("Discarding {} extensions that do not improve base version".format(len(extensions)-keep_exts))
		del extensions[keep_exts:]
		return base_version, extensions
	else:
		if debug:
			print("{} things still missing:".format(len(missing)))
			# NOTE(review): missing is a set, so this listing's order is not
			# guaranteed to be stable between runs.
			for t in missing:
				print("  "+t.name)
		return None, None
504
505
class SourceGenerator:
	"""Writes the C++ header and source that load one extension on a host API.

	Core version, deprecation, backport extension and source extensions are
	detected from *things* at construction time.  *optional_things* are also
	declared and loaded, but take no part in that detection.
	"""

	def __init__(self, host_api, ext_name, things, optional_things, debug=False):
		self.host_api = host_api
		# Prefix of the version macros tested in the generated code.
		self.api_prefix = "GL"
		if self.host_api.name=="gles2":
			self.api_prefix = "GL_ES"
		self.ext_name = ext_name
		all_things = things+optional_things
		self.funcs = [t for t in all_things if t.kind==Thing.FUNCTION]
		self.funcs.sort(key=(lambda f: f.name))
		self.func_typedefs = dict((f.name, "FPtr_"+f.name) for f in self.funcs)
		self.enums = [t for t in all_things if t.kind==Thing.ENUM]
		self.enums.sort(key=(lambda e: e.value))
		self.core_version = detect_core_version(host_api, things, debug)
		self.deprecated_version = detect_deprecated_version(host_api, things, debug)
		self.backport_ext = detect_backport_extension(host_api, things)
		b, e = detect_source_extension(host_api, things, debug)
		self.base_version = b
		self.source_exts = e

		if not self.core_version and not self.backport_ext and not self.source_exts:
			print("Warning: Not supportable on host API")

	def dump_info(self):
		"""Print the detection results for this extension."""
		print("--- Extension information ---")
		print("Extension {}".format(self.ext_name))
		print("Core {}".format(self.core_version))
		print("Deprecated {}".format(self.deprecated_version))
		if self.backport_ext:
			print("Backport {}".format(self.backport_ext.name))
		if self.source_exts:
			names = [e.name for e in self.source_exts]
			if self.base_version:
				names.insert(0, "Version {}".format(self.base_version))
			print("Sources {}".format(", ".join(names)))

	def write_header_intro(self, out):
		"""Write the include guard, includes and namespace openings of the header."""
		out.write("#ifndef MSP_GL_{}_\n".format(self.ext_name.upper()))
		out.write("#define MSP_GL_{}_\n".format(self.ext_name.upper()))

		out.write("""
#include <msp/gl/extension.h>
#include <msp/gl/gl.h>

namespace Msp {
namespace GL {

""")

	def write_enum_definitions(self, out):
		"""Write a #define for every enum, grouped in #ifndef blocks on the
		macro of the core version or preferred extension that provides it."""
		enums_by_category = {}
		for e in self.enums:
			cat = None
			supp = e.api_support.get(self.host_api.name)
			if supp:
				if supp.core_version:
					cat = "{}_{}".format(self.api_prefix, supp.core_version.as_define())
				elif supp.extensions:
					cat = "GL_"+supp.extensions[0].name
			enums_by_category.setdefault(cat, []).append(e)

		# Uncategorised enums (None) come first.  None cannot be compared with
		# strings in Python 3, so it gets its own sort key.
		for cat in sorted(enums_by_category, key=(lambda c: (c is not None, c or ""))):
			if cat:
				out.write("#ifndef {}\n".format(cat))
			for e in enums_by_category[cat]:
				out.write("#define {} 0x{:04X}\n".format(e.name, e.value))
			if cat:
				out.write("#endif\n")
			out.write("\n")

	def write_function_pointer_declarations(self, out):
		"""Write a pointer typedef and an extern declaration for every function."""
		for f in self.funcs:
			typedef = self.func_typedefs[f.name]
			out.write("typedef {} (APIENTRY *{})({});\n".format(f.return_type, typedef, ", ".join(f.params)))
			out.write("extern {} {};\n".format(typedef, f.name))
			out.write("\n")

	def write_header_outro(self, out):
		"""Close the namespaces and the include guard of the header."""
		out.write("""
} // namespace GL
} // namespace Msp

#endif
""")

	def write_source_intro(self, out):
		"""Write the include, GET_PROC_ADDRESS macros (if needed) and namespace openings."""
		out.write("#include \"{}.h\"\n".format(self.ext_name.lower()))
		if self.funcs:
			out.write("""
#ifdef __APPLE__
#define GET_PROC_ADDRESS(x) &::x
#else
#define GET_PROC_ADDRESS(x) get_proc_address(#x)
#endif

#ifdef _WIN32
#define GET_PROC_ADDRESS_1_1(x) &::x
#else
#define GET_PROC_ADDRESS_1_1(x) GET_PROC_ADDRESS(x)
#endif
""")
		out.write("""
namespace Msp {
namespace GL {

""")

	def write_function_pointer_definitions(self, out):
		"""Define every function pointer, initialised to 0."""
		for f in self.funcs:
			out.write("{} {} = 0;\n".format(self.func_typedefs[f.name], f.name))

	def write_init_function(self, out):
		"""Write init_<ext>(), which loads the functions through the core
		version (or backport extension) if possible, else through the source
		extensions, and reports the resulting support level."""
		out.write("\nExtension::SupportLevel init_{}()\n{{\n".format(self.ext_name.lower()))
		if self.core_version:
			out.write("\tif(is_disabled(\"GL_{}\"))\n\t\treturn Extension::UNSUPPORTED;\n".format(self.ext_name))
			out.write("#if !defined(__APPLE__) || defined({}_{})\n".format(self.api_prefix, self.core_version.as_define()))
			out.write("\tif(")
			if self.backport_ext:
				out.write("is_supported(\"GL_{}\") || ".format(self.backport_ext.name))
			out.write("is_supported({!r}".format(self.core_version))
			if self.deprecated_version:
				out.write(", {!r}".format(self.deprecated_version))
			out.write("))\n\t{\n")
			for f in self.funcs:
				supp = f.api_support.get(self.host_api.name)
				if supp:
					# Functions core in 1.1 or earlier use GET_PROC_ADDRESS_1_1.
					gpa_suffix = ""
					if supp.core_version is not None and supp.core_version<=Version(1, 1):
						gpa_suffix = "_1_1"
					out.write("\t\t{} = reinterpret_cast<{}>(GET_PROC_ADDRESS{}({}));\n".format(f.name, self.func_typedefs[f.name], gpa_suffix, f.name))
			out.write("\t\treturn Extension::CORE;\n")
			out.write("\t}\n")
			out.write("#endif\n")
		if self.source_exts:
			out.write("#if !defined(__APPLE__) || defined(GL_{})\n".format(self.ext_name))
			out.write("\tif(")
			if self.base_version:
				out.write("is_supported({!r}) && ".format(self.base_version))
			out.write("{})\n\t{{\n".format(" && ".join("is_supported(\"GL_%s\")"%s.name for s in self.source_exts)))
			for f in self.funcs:
				supp = f.api_support.get(self.host_api.name)
				# Find the function itself, or an alias source of it, in one
				# of the source extensions.
				src = None
				for e in self.source_exts:
					if f.name in e.things:
						src = f
					elif supp:
						for s in supp.sources:
							if s.name in e.things:
								src = s
								break
					if src:
						break
				# Otherwise load it under its own name when the base version
				# already has it in core.
				if not src and supp and supp.core_version and self.base_version and self.base_version>=supp.core_version:
					src = f

				if src:
					out.write("\t\t{} = reinterpret_cast<{}>(GET_PROC_ADDRESS({}));\n".format(f.name, self.func_typedefs[f.name], src.name))
			out.write("\t\treturn Extension::EXTENSION;\n")
			out.write("\t}\n")
			out.write("#endif\n")
		out.write("\treturn Extension::UNSUPPORTED;\n")
		out.write("}\n")

	def write_source_outro(self, out):
		"""Close the namespaces of the source file."""
		out.write("""
} // namespace GL
} // namespace Msp
""")

	def write_header(self, fn):
		"""Write the complete header to file *fn*."""
		with open(fn, "w") as out:
			self.write_header_intro(out)
			self.write_enum_definitions(out)
			self.write_function_pointer_declarations(out)
			out.write("extern Extension {};\n".format(self.ext_name))
			self.write_header_outro(out)

	def write_source(self, fn):
		"""Write the complete source file to file *fn*."""
		with open(fn, "w") as out:
			self.write_source_intro(out)
			self.write_function_pointer_definitions(out)
			self.write_init_function(out)
			out.write("\nExtension {}(\"GL_{}\", init_{});\n".format(self.ext_name, self.ext_name, self.ext_name.lower()))
			self.write_source_outro(out)
690
691
def dump_api_support(supp, api, indent):
	"""Print a thing's support record *supp* for *api*, prefixing every line
	with *indent*.  Sources are dumped recursively, indented further.

	A missing record (None, as api_support.get() returns for an API the
	thing has no entry for) prints nothing instead of raising.
	"""
	if supp is None:
		return
	if supp.core_version:
		print(indent+"core in version {}".format(supp.core_version))
	if supp.deprecated_version:
		print(indent+"deprecated in version {}".format(supp.deprecated_version))
	for e in supp.extensions:
		print(indent+"extension {} (preference {})".format(e.name, e.preference))
	for r in supp.sources:
		print(indent+"source {}".format(r.name))
		dump_thing_info(r, api, indent+"  ")
702
def dump_thing_info(thing, api, indent):
        """Print debug information about *thing* to stdout, prefixing lines with *indent*.

        Lists the thing's aliases.  If *api* is given, dumps only the support
        record for that API; a thing with no record for it gets a
        "not supported" line instead of crashing.  Without *api*, dumps the
        support record of every API the thing appears in.
        """
        for a in thing.aliases:
                print(indent+"alias {}".format(a))
        if api:
                supp = thing.api_support.get(api)
                if supp is None:
                        # e.g. a source thing reached through dump_api_support may not
                        # exist in the API being dumped
                        print(indent+"not supported in api {}".format(api))
                else:
                        dump_api_support(supp, api, indent)
        else:
                for a, s in thing.api_support.items():
                        print(indent+"api {}".format(a))
                        dump_api_support(s, a, indent+"  ")
713
714
class ExtensionParser:
        """Reader for the extension description files consumed by extgen.

        Each meaningful line has the form ``[api:]keyword argument``.  Blank
        lines and lines starting with ``#`` are ignored, as are lines whose
        ``api:`` prefix names an API other than the host API.
        """

        # Keywords understood by parse(); each takes exactly one argument.
        _KEYWORDS = ("extension", "core_version", "deprecated", "backport", "source", "ignore", "optional")

        def __init__(self, host_api):
                # API name used to filter "api:"-prefixed lines
                self.host_api = host_api
                # Name of the extension to generate ("extension" keyword)
                self.target_ext = None
                # Version objects from "core_version" / "deprecated"
                self.core_version = None
                self.deprecated_version = None
                # Raw name from "backport" (may be the literal "none")
                self.backport_ext = None
                # Names accumulated from repeated "source" / "ignore" / "optional" lines
                self.source_exts = []
                self.ignore_things = []
                self.optional_things = []

        @staticmethod
        def _parse_version(text):
                """Convert "major.minor" to a Version, or print an error and return None."""
                try:
                        return Version(*map(int, text.split('.')))
                except (ValueError, TypeError):
                        # ValueError: non-numeric component; TypeError: Version() needs exactly two parts
                        print("Invalid version "+text)
                        return None

        def parse(self, fn):
                """Read the description file *fn* and store its settings on self.

                Returns True on success.  On an unknown keyword, a keyword
                without an argument or a malformed version number, prints a
                message and returns False; settings read before the offending
                line are kept.
                """
                with open(fn) as f:
                        for line in f:
                                line = line.strip()
                                if not line or line.startswith("#"):
                                        continue

                                parts = line.split()
                                api = None
                                keyword = parts[0]
                                if ":" in keyword:
                                        # Only the first colon separates the API prefix
                                        api, keyword = keyword.split(":", 1)

                                if api is not None and api!=self.host_api:
                                        continue

                                if keyword not in self._KEYWORDS:
                                        print("Unknown keyword "+keyword)
                                        return False
                                if len(parts)<2:
                                        print("Missing argument for keyword "+keyword)
                                        return False
                                arg = parts[1]

                                if keyword=="extension":
                                        self.target_ext = arg
                                elif keyword=="core_version":
                                        version = self._parse_version(arg)
                                        if version is None:
                                                return False
                                        self.core_version = version
                                elif keyword=="deprecated":
                                        version = self._parse_version(arg)
                                        if version is None:
                                                return False
                                        self.deprecated_version = version
                                elif keyword=="backport":
                                        self.backport_ext = arg
                                elif keyword=="source":
                                        self.source_exts.append(arg)
                                elif keyword=="ignore":
                                        self.ignore_things.append(arg)
                                elif keyword=="optional":
                                        self.optional_things.append(arg)

                return True
760
761
def get_extension(api_map, ext_name):
        """Look up an extension by name in *api_map*.

        *ext_name* is either a bare extension name, which is looked up in the
        "gl" API, or ``<api>.<name>`` to select a different API.
        """
        api_name = "gl"
        if "." in ext_name:
                api_name, ext_name = ext_name.split(".")
        api = api_map[api_name]
        return api.extensions[ext_name]
769
def resolve_things(api, things):
        """Replace things by the core things of *api* they are aliased to.

        For each thing, every alias present in ``api.core_things`` contributes
        the corresponding core thing, in alias order.  A thing with no such
        alias is kept unchanged.  Returns a new list.
        """
        resolved = []
        for thing in things:
                core_equivalents = [api.core_things[alias] for alias in thing.aliases if alias in api.core_things]
                resolved.extend(core_equivalents or [thing])
        return resolved
780
def collect_extension_things(host_api, target_ext, ignore):
        """Return the things of *target_ext* whose names are not in *ignore*.

        Aliases are resolved against the extension's own API through
        resolve_things.  *host_api* is accepted for interface compatibility
        but is not used here.
        """
        selected = []
        for name, thing in target_ext.things.items():
                if name in ignore:
                        continue
                selected.append(thing)
        return resolve_things(target_ext.api, selected)
784
def collect_optional_things(target_ext, names):
        """Look up the optional things named in *names* and resolve their aliases.

        Each name is taken from ``target_ext.things`` if present, otherwise
        from the core things of the extension's API (raising KeyError if it
        is in neither).
        """
        picked = [target_ext.things[name] if name in target_ext.things else target_ext.api.core_things[name] for name in names]
        return resolve_things(target_ext.api, picked)
793
def _print_usage_and_exit():
        """Print the command line syntax to stdout and exit with status 1."""
        print("""Usage:
  extgen.py [-g] [api] <extfile> [<outfile>]

Reads gl.xml and generates C++ source files to use an OpenGL extension
described in <extfile>.  If <outfile> is absent, the extension's lowercased
name is used.  Anything after the last dot in <outfile> is removed and
replaced with cpp and h.  With -g, information about the collected things
is printed for debugging.""")
        sys.exit(1)

def main():
        """Command line entry point: parse arguments, load GL definitions and generate files.

        Arguments are ``[-g] [api] <extfile> [<outfile>]``.  The definitions
        are read from gl.xml, gl.fixes.xml and gl.msp.xml (relative paths, so
        from the current working directory).  Exits with status 1 on missing
        arguments, an unparseable extension file or one that names no
        extension.
        """
        args = sys.argv
        i = 1

        debug = False
        if i<len(args) and args[i]=="-g":
                debug = True
                i += 1

        # An argument starting with "gl" selects the host API (e.g. gles2)
        host_api_name = "gl"
        if i<len(args) and args[i].startswith("gl"):
                host_api_name = args[i]
                i += 1

        # The extension description file is mandatory
        if i>=len(args):
                _print_usage_and_exit()

        ext_fn = args[i]
        i += 1
        ext_parser = ExtensionParser(host_api_name)
        if not ext_parser.parse(ext_fn):
                sys.exit(1)
        if ext_parser.target_ext is None:
                print("No extension specified in "+ext_fn)
                sys.exit(1)

        if i<len(args):
                out_base = os.path.splitext(args[i])[0]
        else:
                out_base = ext_parser.target_ext.lower()

        xml_parser = GlXmlParser()
        xml_parser.parse_file("gl.xml")
        xml_parser.parse_file("gl.fixes.xml")
        xml_parser.parse_file("gl.msp.xml")
        xml_parser.finalize()

        host_api = xml_parser.apis[host_api_name]
        target_ext = get_extension(xml_parser.apis, ext_parser.target_ext)
        # Optional things are excluded here and collected separately below
        things = collect_extension_things(host_api, target_ext, ext_parser.ignore_things+ext_parser.optional_things)
        optional_things = collect_optional_things(target_ext, ext_parser.optional_things)

        if debug:
                print("--- Things included in this extension ---")
                all_things = things+optional_things
                all_things.sort(key=(lambda t: t.name))
                for t in all_things:
                        print(t.name)
                        if t in optional_things:
                                print("  optional")
                        dump_thing_info(t, None, "  ")

        generator = SourceGenerator(host_api, target_ext.name, things, optional_things, debug)
        if ext_parser.core_version:
                generator.core_version = ext_parser.core_version
        if ext_parser.deprecated_version:
                generator.deprecated_version = ext_parser.deprecated_version
        if ext_parser.backport_ext:
                if ext_parser.backport_ext=="none":
                        generator.backport_ext = None
                else:
                        generator.backport_ext = get_extension(xml_parser.apis, ext_parser.backport_ext)
        if ext_parser.source_exts:
                generator.base_version = None
                if len(ext_parser.source_exts)==1 and ext_parser.source_exts[0]=="none":
                        generator.source_exts = []
                else:
                        # Build a real list: map() is a one-shot iterator in Python 3 and
                        # would yield nothing if iterated a second time.
                        generator.source_exts = [get_extension(xml_parser.apis, e) for e in ext_parser.source_exts]
        if debug:
                generator.dump_info()
        generator.write_header(out_base+".h")
        generator.write_source(out_base+".cpp")

if __name__=="__main__":
        main()