]> git.tdb.fi Git - libs/gl.git/blob - scripts/extgen.py
Add a class to unify loading coordinate transforms
[libs/gl.git] / scripts / extgen.py
1 #!/usr/bin/python
2
3 import sys
4 import os
5 import xml.dom
6 import xml.dom.minidom
7 import itertools
8 import re
9
class Version:
        """A major.minor version number such as 3.2.

        Version() is 0.0; Version(major, minor) sets both parts.  Comparisons
        treat None as lower than every version, so a Version can be compared
        against an unset (None) version.  All six comparison operators are
        defined explicitly: Python 2 does not derive <=, >= or != from the
        others (it silently falls back to identity comparison), and Python 3
        rejects <= and >= when they are missing.
        """
        def __init__(self, *args):
                if len(args)==0:
                        self.major = 0
                        self.minor = 0
                elif len(args)==2:
                        self.major = args[0]
                        self.minor = args[1]
                else:
                        raise TypeError("__init__() takes zero or two arguments (%d given)"%len(args))

        def __str__(self):
                return "%d.%d"%(self.major, self.minor)

        def __repr__(self):
                return "Version(%d, %d)"%(self.major, self.minor)

        def as_define(self):
                """Return the suffix used in GL version macros, e.g. VERSION_3_2."""
                return "VERSION_%d_%d"%(self.major, self.minor)

        # Tuple comparison orders by major first, then minor.
        def __lt__(self, other):
                if other is None:
                        return False
                return (self.major, self.minor)<(other.major, other.minor)

        def __le__(self, other):
                if other is None:
                        return False
                return (self.major, self.minor)<=(other.major, other.minor)

        def __gt__(self, other):
                if other is None:
                        return True
                return (self.major, self.minor)>(other.major, other.minor)

        def __ge__(self, other):
                if other is None:
                        return True
                return (self.major, self.minor)>=(other.major, other.minor)

        def __eq__(self, other):
                if other is None:
                        return False
                return (self.major, self.minor)==(other.major, other.minor)

        def __ne__(self, other):
                return not self.__eq__(other)

        def __hash__(self):
                # Equal versions must hash equally now that __eq__ is by value.
                return hash((self.major, self.minor))
51
52
class Thing:
        """Common base for registry entities (commands and enums).

        Support information is kept per API name in `api_support`, one
        Thing.ApiSupport record each.
        """
        FUNCTION = 1
        ENUM = 2

        class ApiSupport:
                """How one Thing is available within one API."""
                def __init__(self):
                        self.core_version = None        # version in which it became core, if any
                        self.deprecated_version = None  # version whose feature removes it, if any
                        self.extensions = []            # extensions providing it
                        self.sources = []               # other things it can be loaded from

        def __init__(self, name, kind):
                self.name = name
                self.kind = kind        # Thing.FUNCTION or Thing.ENUM
                self.aliases = []
                self.api_support = {}   # API name -> Thing.ApiSupport

        def get_or_create_api_support(self, api):
                """Return the ApiSupport record for `api`, creating it on first use."""
                existing = self.api_support.get(api)
                if existing:
                        return existing
                created = self.api_support[api] = Thing.ApiSupport()
                return created
76
77
class Function(Thing):
        """A GL command: its return type and parameter declarations as text."""
        def __init__(self, name):
                Thing.__init__(self, name, Thing.FUNCTION)
                self.params = []                # parameter declarations, in order
                self.return_type = "void"
83
84
# Enum names containing "_BIT", optionally followed by digits, and then either
# "_" or the end of the name (e.g. GL_COLOR_BUFFER_BIT) are treated as bitmasks.
r_bitmask = re.compile(r"_BIT[0-9]*(_|$)")
86
class Enum(Thing):
        """A GL enumerant and its numeric value.

        `bitmask` is True when the name matches r_bitmask.
        """
        def __init__(self, name):
                Thing.__init__(self, name, Thing.ENUM)
                self.value = 0
                self.bitmask = r_bitmask.search(name) is not None
92
93
class Extension:
        """One extension (e.g. ARB_vertex_buffer_object) as registered for one API.

        `name` is split at its first underscore into ext_type (the vendor tag,
        e.g. "ARB") and base_name.  A name without an underscore gets an empty
        ext_type and the whole name as base_name.  `preference` ranks vendor
        tags: ARB and OES 2, EXT 1, anything else 0.
        """
        # Vendor tag -> preference; higher values are preferred.
        _PREFERENCES = {"EXT": 1, "ARB": 2, "OES": 2}

        def __init__(self, name, api):
                self.name = name
                # partition() copes with a name that has no underscore.
                ext_type, sep, base_name = name.partition('_')
                if sep:
                        self.ext_type = ext_type
                        self.base_name = base_name
                else:
                        self.ext_type = ""
                        self.base_name = name
                self.api = api
                self.things = {}        # thing name -> Thing provided by this extension
                self.preference = Extension._PREFERENCES.get(self.ext_type, 0)
                self.backport = False
108
109
class Api:
        """One API (e.g. "gles2") with its core things and its extensions."""
        def __init__(self, name):
                self.name = name
                self.latest_version = None
                # thing name -> Thing, and extension name -> Extension
                self.core_things, self.extensions = {}, {}
116
117
def get_nested_elements(elem, path):
        """Return all descendant elements of `elem` matching a slash-separated tag path.

        "a/b" yields every <b> child of every <a> child of `elem`, in document order.
        """
        head, sep, tail = path.partition('/')
        matches = [c for c in elem.childNodes
                   if c.nodeType==xml.dom.Node.ELEMENT_NODE and c.tagName==head]
        if not sep:
                return matches

        found = []
        for m in matches:
                found.extend(get_nested_elements(m, tail))
        return found
129
def get_first_child(elem, tag):
        """Return the first direct child element of `elem` named `tag`, or None."""
        candidates = (c for c in elem.childNodes
                      if c.nodeType==xml.dom.Node.ELEMENT_NODE and c.tagName==tag)
        return next(candidates, None)
135
def get_text_contents(node):
        """Return all text and CDATA beneath `node`, concatenated in document order."""
        text_kinds = (xml.dom.Node.TEXT_NODE, xml.dom.Node.CDATA_SECTION_NODE)
        return "".join(c.data if c.nodeType in text_kinds else get_text_contents(c)
                       for c in node.childNodes)
144
def get_or_create(map, name, type, *args):
        """Return map[name], first storing type(name, *args) there if it is absent.

        The parameter names shadow the builtins map/type; they are kept as-is
        for compatibility with any keyword callers.
        """
        obj = map.get(name)
        # Test for absence explicitly: a stored value that happens to be falsy
        # (e.g. an empty container) must not be silently replaced.
        if obj is None:
                obj = type(name, *args)
                map[name] = obj
        return obj
151
152
class GlXmlParser:
        """Build Thing, Api and Extension objects from a Khronos gl.xml registry.

        Call parse_file() for each registry file, then finalize().  Afterwards
        self.apis maps API names to Api objects and self.things maps command
        and enum names to Function and Enum objects.
        """

        def __init__(self):
                self.apis = {}          # API name -> Api
                self.things = {}        # command/enum name -> Function/Enum

        def parse_command(self, cmd):
                """Record a <command> element: name, aliases, return type and params."""
                proto = get_first_child(cmd, "proto")
                name = get_text_contents(get_first_child(proto, "name"))
                func = get_or_create(self.things, name, Function)

                aliases = get_nested_elements(cmd, "alias")
                func.aliases = [a.getAttribute("name") for a in aliases]

                ptype = get_first_child(proto, "ptype")
                if ptype:
                        func.return_type = get_text_contents(ptype)
                else:
                        # No <ptype>: use the first non-blank text directly inside <proto>.
                        for c in proto.childNodes:
                                if c.nodeType==xml.dom.Node.TEXT_NODE and c.data.strip():
                                        func.return_type = c.data.strip()
                                        break

                params = get_nested_elements(cmd, "param")
                # A concrete list, not map(), which is a one-shot iterator on Python 3.
                func.params = [get_text_contents(p) for p in params]

        def parse_enum(self, en):
                """Record an <enum> element: name, value (parsed as base 16) and alias."""
                name = en.getAttribute("name")
                enum = get_or_create(self.things, name, Enum)

                enum.value = int(en.getAttribute("value"), 16)

                alias = en.getAttribute("alias")
                if alias:
                        enum.aliases.append(alias)

        def parse_feature(self, feat):
                """Apply a <feature> element to the things it requires or removes.

                Required things become core in the feature's version (keeping the
                lowest version seen so far) and are added to the API's core_things.
                Removed things get the feature's version as deprecated_version.
                """
                api_name = feat.getAttribute("api")
                api = get_or_create(self.apis, api_name, Api)

                version = feat.getAttribute("number")
                if version:
                        version = Version(*map(int, version.split('.')))
                else:
                        version = None

                for req in get_nested_elements(feat, "require"):
                        commands = get_nested_elements(req, "command")
                        enums = get_nested_elements(req, "enum")
                        for t in itertools.chain(commands, enums):
                                thing = self.things.get(t.getAttribute("name"))
                                if thing:
                                        supp = thing.get_or_create_api_support(api.name)
                                        if not supp.core_version or version<supp.core_version:
                                                supp.core_version = version
                                        api.core_things[thing.name] = thing

                for rem in get_nested_elements(feat, "remove"):
                        commands = get_nested_elements(rem, "command")
                        enums = get_nested_elements(rem, "enum")
                        for t in itertools.chain(commands, enums):
                                thing = self.things.get(t.getAttribute("name"))
                                if thing:
                                        supp = thing.get_or_create_api_support(api.name)
                                        supp.deprecated_version = version

        def parse_extension(self, ext):
                """Register an <extension> element with every known API it supports.

                Things in a <require> without an api attribute apply to all
                supported APIs; those with one apply only to that API.
                """
                ext_things_by_api = {}
                for req in get_nested_elements(ext, "require"):
                        api = req.getAttribute("api")
                        ext_things = ext_things_by_api.setdefault(api, [])

                        commands = get_nested_elements(req, "command")
                        enums = get_nested_elements(req, "enum")
                        for t in itertools.chain(commands, enums):
                                thing = self.things.get(t.getAttribute("name"))
                                if thing:
                                        ext_things.append(thing)

                ext_name = ext.getAttribute("name")
                if ext_name.startswith("GL_"):
                        ext_name = ext_name[3:]

                common_things = ext_things_by_api.get("", [])
                supported = ext.getAttribute("supported").split('|')
                for s in supported:
                        api = self.apis.get(s)
                        if not api:
                                continue

                        # Distinct name so the `ext` XML element is not shadowed.
                        extension = get_or_create(api.extensions, ext_name, Extension, api)
                        api_things = ext_things_by_api.get(s, [])
                        for t in itertools.chain(common_things, api_things):
                                extension.things[t.name] = t
                                supp_exts = t.get_or_create_api_support(api.name).extensions
                                # A thing may be listed in several <require> blocks;
                                # record each extension only once per thing.
                                if extension not in supp_exts:
                                        supp_exts.append(extension)

        def parse_file(self, fn):
                """Parse one registry file.

                Commands and enums are parsed first so that features and
                extensions can look them up by name.
                """
                doc = xml.dom.minidom.parse(fn)
                root = doc.documentElement

                for cmd in get_nested_elements(root, "commands/command"):
                        self.parse_command(cmd)

                for en in get_nested_elements(root, "enums/enum"):
                        self.parse_enum(en)

                for feat in get_nested_elements(root, "feature"):
                        self.parse_feature(feat)

                for ext in get_nested_elements(root, "extensions/extension"):
                        self.parse_extension(ext)

        def check_backport_extensions(self, api):
                """Flag ARB extensions whose things carry no "ARB" suffix as backports.

                Such extensions use the unsuffixed (core) names for everything
                they provide.
                """
                for e in api.extensions.values():
                        if e.ext_type!="ARB":
                                continue
                        e.backport = not any(t.name.endswith(e.ext_type) for t in e.things.values())

        def resolve_enum_aliases(self, api):
                """Alias non-core extension enums to same-valued core enums.

                An extension enum NAME_<ext_type> (e.g. FOO_ARB) that is not core
                in `api` gets the core enum NAME as an alias when both values are
                equal.  Names without the suffix are looked up unchanged.
                """
                for e in api.extensions.values():
                        enum_suffix = "_"+e.ext_type
                        for n in e.things.values():
                                if n.kind!=Thing.ENUM:
                                        continue
                                if n.api_support[api.name].core_version:
                                        continue

                                name = n.name
                                if name.endswith(enum_suffix):
                                        name = name[:-len(enum_suffix)]
                                ce = api.core_things.get(name)
                                # Only enums carry a value to compare.
                                if ce and ce.kind==Thing.ENUM and ce.value==n.value and ce.name not in n.aliases:
                                        n.aliases.append(ce.name)

        def resolve_sources(self, api):
                """Record each extension thing as a source of the core things it aliases."""
                for e in api.extensions.values():
                        for t in e.things.values():
                                for a in t.aliases:
                                        # There are a few cases where a vendor function is aliased to
                                        # an EXT or ARB function but those are rare and not relevant for
                                        # our use
                                        alias = api.core_things.get(a)
                                        if alias:
                                                sources = alias.api_support[api.name].sources
                                                if t not in sources:
                                                        sources.append(t)

        def sort_extensions(self):
                """Order every thing's extensions by descending preference (stable)."""
                for t in self.things.values():
                        for s in t.api_support.values():
                                s.extensions.sort(key=(lambda e: e.preference), reverse=True)

        def finalize(self):
                """Run the post-parse passes for every API, then sort extensions."""
                for a in self.apis.values():
                        self.check_backport_extensions(a)
                        self.resolve_enum_aliases(a)
                        self.resolve_sources(a)
                self.sort_extensions()
325
326
def detect_core_version(host_api, things, debug=None):
        """Return the core version of host_api that contains all of `things`.

        This is the highest core version among the things (never below 1.0).
        Returns None if any thing is not core in host_api at all, listing those
        things when debug is set.  Prints a warning when more things belong
        to lower versions than to the highest one, or when only a minority of
        the things are missing from core.
        """
        highest = Version(1, 0)
        at_highest = 0
        below_highest = 0
        absent = []
        for t in things:
                supp = t.api_support.get(host_api.name)
                if not (supp and supp.core_version):
                        absent.append(t)
                        continue

                ver = supp.core_version
                if ver>highest:
                        # Everything counted at the old maximum is now below it.
                        below_highest += at_highest
                        highest = ver
                        at_highest = 1
                elif ver==highest:
                        at_highest += 1
                else:
                        below_highest += 1

        if below_highest>at_highest or (absent and len(absent)*2<below_highest+at_highest):
                print("Warning: Inconsistent core version %s"%highest)

        if not absent:
                return highest

        if debug:
                print("---")
                print("%d things missing from core:"%len(absent))
                for t in absent:
                        print("  "+t.name)
        return None
358
def detect_deprecated_version(host_api, things, debug):
        """Return the earliest version of host_api that removes any of `things`.

        Returns None when none of them is removed.  Prints a warning when fewer
        than half of the things are removed, listing them if debug is set.
        """
        deprecated = []
        versions = []
        for t in things:
                supp = t.api_support.get(host_api.name)
                if supp and supp.deprecated_version:
                        deprecated.append(t)
                        versions.append(supp.deprecated_version)

        min_version = min(versions) if versions else None

        if min_version and len(deprecated)*2<len(things):
                print("Warning: Inconsistent deprecation version %s"%min_version)
                if debug:
                        print("---")
                        print("%d things are deprecated:"%len(deprecated))
                        for t in deprecated:
                                print("  "+t.name)

        return min_version
380
def detect_backport_extension(host_api, things):
        """Find a backport extension that provides every one of `things`.

        Candidates are the extensions flagged `backport` that are listed for
        things which are core in host_api.  Returns the first candidate that
        contains all of the things, otherwise None.  If the best candidate
        covers at least half (but not all) of them, a warning is printed and
        the candidate is still not used.
        """
        candidates = []
        for t in things:
                supp = t.api_support.get(host_api.name)
                if supp and supp.core_version:
                        for e in supp.extensions:
                                if e.backport and e not in candidates:
                                        candidates.append(e)

        total_count = len(things)
        best_ext = None
        best_count = 0
        for e in candidates:
                count = sum(1 for t in things if t.name in e.things)
                if count==total_count:
                        return e
                elif count>best_count:
                        best_ext = e
                        best_count = count

        # best_ext stays None when there were no candidates (e.g. `things` is
        # empty, where 0*2>=0 would otherwise dereference None).
        if best_ext is not None and best_count*2>=total_count:
                print("Warning: Inconsistent backport extension %s"%best_ext.name)
        return None
404
def collect_extensions(thing, api, exts):
        """Append to `exts` each extension that can provide `thing` in `api`.

        The thing's sources are followed recursively.  Backport extensions and
        ones of type "MSP" are skipped, and no extension is added twice.
        `exts` is modified in place; nothing is returned.
        """
        supp = thing.api_support.get(api)
        if supp:
                usable = (e for e in supp.extensions if not e.backport and e.ext_type!="MSP")
                for ext in usable:
                        if ext not in exts:
                                exts.append(ext)

                for source in supp.sources:
                        collect_extensions(source, api, exts)
416
def detect_source_extension(host_api, things, debug=False):
        """Choose extensions (and possibly a base core version) that supply `things`.

        Extensions are picked greedily: each round takes the extension that
        provides the most still-missing things, preferring the higher
        Extension.preference on ties.  Between rounds, base_version tracks the
        lowest core version in which every remaining missing thing is core.

        Returns a (base_version, extensions) pair:
          (None, extensions)          - the chosen extensions cover every thing;
          (base_version, extensions)  - the extensions chosen up to the point
                                        base_version was last lowered; the rest
                                        of the things are core in base_version;
          (None, None)                - the things cannot all be supplied.
        """
        # Map each candidate extension to the things it can provide, directly
        # or through one of the thing's sources (see collect_extensions).
        things_by_ext = {}
        for t in things:
                exts = []
                collect_extensions(t, host_api.name, exts)
                for e in exts:
                        things_by_ext.setdefault(e, []).append(t)

        if debug:
                print "---"
                print "Looking for %d things in %d extensions"%(len(things), len(things_by_ext))

        extensions = []
        keep_exts = 0
        base_version = None
        recheck_base_version = True
        missing = set(things)
        while 1:
                if recheck_base_version:
                        # Highest core version among the remaining things; None if
                        # any of them is not core in host_api at all.
                        max_version = Version(1, 0)
                        for t in missing:
                                supp = t.api_support.get(host_api.name)
                                if supp and supp.core_version and max_version:
                                        max_version = max(max_version, supp.core_version)
                                else:
                                        max_version = None

                        # Remember how many extensions were chosen when the
                        # base version reached its lowest value so far.
                        if max_version:
                                if not base_version or max_version<base_version:
                                        base_version = max_version
                                        keep_exts = len(extensions)
                        elif not base_version:
                                keep_exts = len(extensions)

                        recheck_base_version = False

                if not missing or not things_by_ext:
                        break

                # Pick the extension covering the most missing things.  Ties with
                # equal preference are decided by dict iteration order.
                largest_ext = None
                largest_count = 0
                for e, t in things_by_ext.iteritems():
                        count = len(t)
                        if count>largest_count:
                                largest_ext = e
                                largest_count = count
                        elif count==largest_count and e.preference>largest_ext.preference:
                                largest_ext = e

                if debug:
                        print "Found %d things in %s"%(largest_count, largest_ext.name)

                extensions.append(largest_ext)
                for t in things_by_ext[largest_ext]:
                        missing.remove(t)

                        # Covering a thing that is core in exactly base_version may
                        # lower the version the remaining things need.
                        supp = t.api_support.get(host_api.name)
                        if supp and supp.core_version==base_version:
                                recheck_base_version = True

                # Drop covered things from the other candidates, and candidates
                # left with nothing to provide.
                # NOTE(review): keys() and filter() rely on Python 2 returning
                # lists; on Python 3 deleting keys here would raise RuntimeError.
                del things_by_ext[largest_ext]
                for e in things_by_ext.keys():
                        unseen = filter((lambda t: t in missing), things_by_ext[e])
                        if unseen:
                                things_by_ext[e] = unseen
                        else:
                                del things_by_ext[e]

        if not missing:
                return None, extensions
        elif base_version:
                if debug:
                        print "Found remaining things in version %s"%base_version
                        if keep_exts<len(extensions):
                                print "Discarding %d extensions that do not improve base version"%(len(extensions)-keep_exts)
                # Extensions chosen after base_version was last lowered are not needed.
                del extensions[keep_exts:]
                return base_version, extensions
        else:
                if debug:
                        print "%d things still missing:"%len(missing)
                        for t in missing:
                                print "  "+t.name
                return None, None
500
501
class SourceGenerator:
        """Emit the C++ header and source that declare and load one extension.

        host_api is the Api being generated for, ext_name the extension name
        without the "GL_" prefix, and things / optional_things lists of the
        Function and Enum objects to emit.  Only `things` are used to detect
        how the extension can be supported (core version, deprecation, backport
        extension, source extensions); optional things are emitted alongside.
        """
        def __init__(self, host_api, ext_name, things, optional_things, debug=False):
                self.host_api = host_api
                self.api_prefix = "GL"
                if self.host_api.name=="gles2":
                        self.api_prefix = "GL_ES"
                self.ext_name = ext_name
                all_things = things+optional_things
                # Real sorted lists (filter() is lazy on Python 3 and has no sort()).
                self.funcs = sorted((t for t in all_things if t.kind==Thing.FUNCTION), key=(lambda f: f.name))
                self.func_typedefs = dict((f.name, "FPtr_"+f.name) for f in self.funcs)
                self.enums = sorted((t for t in all_things if t.kind==Thing.ENUM), key=(lambda e: e.value))
                self.core_version = detect_core_version(host_api, things, debug)
                self.deprecated_version = detect_deprecated_version(host_api, things, debug)
                self.backport_ext = detect_backport_extension(host_api, things)
                b, e = detect_source_extension(host_api, things, debug)
                self.base_version = b
                self.source_exts = e

                if not self.core_version and not self.backport_ext and not self.source_exts:
                        print("Warning: Not supportable on host API")

        def dump_info(self):
                """Print a summary of how the extension can be supported."""
                print("--- Extension information ---")
                print("Extension %s"%self.ext_name)
                print("Core %s"%self.core_version)
                print("Deprecated %s"%self.deprecated_version)
                if self.backport_ext:
                        print("Backport %s"%self.backport_ext.name)
                if self.source_exts:
                        names = [e.name for e in self.source_exts]
                        if self.base_version:
                                names.insert(0, "Version %s"%self.base_version)
                        print("Sources %s"%", ".join(names))

        def write_header_intro(self, out):
                out.write("#ifndef MSP_GL_%s_\n"%self.ext_name.upper())
                out.write("#define MSP_GL_%s_\n"%self.ext_name.upper())

                out.write("""
#include <msp/gl/extension.h>
#include <msp/gl/gl.h>

namespace Msp {
namespace GL {

""")

        def write_enum_definitions(self, out):
                """Write #defines for the enums, grouped by the version/extension that provides them.

                Each group is wrapped in #ifndef of its category macro so that
                definitions from the system headers take precedence.
                """
                enums_by_category = {}
                for e in self.enums:
                        cat = None
                        supp = e.api_support.get(self.host_api.name)
                        if supp:
                                if supp.core_version:
                                        cat = "%s_%s"%(self.api_prefix, supp.core_version.as_define())
                                elif supp.extensions:
                                        cat = "GL_"+supp.extensions[0].name
                        enums_by_category.setdefault(cat, []).append(e)

                # The uncategorized group (None) sorts first, as Python 2's mixed
                # None/str ordering did; the explicit key also works on Python 3.
                for cat in sorted(enums_by_category, key=(lambda c: (c is not None, c or ""))):
                        if cat:
                                out.write("#ifndef %s\n"%cat)
                        for e in enums_by_category[cat]:
                                out.write("#define %s 0x%04X\n"%(e.name, e.value))
                        if cat:
                                out.write("#endif\n")
                        out.write("\n")

        def write_function_pointer_declarations(self, out):
                for f in self.funcs:
                        typedef = self.func_typedefs[f.name]
                        out.write("typedef %s (APIENTRY *%s)(%s);\n"%(f.return_type, typedef, ", ".join(f.params)))
                        out.write("extern %s %s;\n"%(typedef, f.name))
                        out.write("\n")

        def write_header_outro(self, out):
                out.write("""
} // namespace GL
} // namespace Msp

#endif
""")

        def write_source_intro(self, out):
                out.write("#include \"%s.h\"\n"%self.ext_name.lower())
                if self.funcs:
                        out.write("""
#ifdef __APPLE__
#define GET_PROC_ADDRESS(x) &::x
#else
#define GET_PROC_ADDRESS(x) get_proc_address(#x)
#endif

#ifdef _WIN32
#define GET_PROC_ADDRESS_1_1(x) &::x
#else
#define GET_PROC_ADDRESS_1_1(x) GET_PROC_ADDRESS(x)
#endif
""")
                out.write("""
namespace Msp {
namespace GL {

""")

        def write_function_pointer_definitions(self, out):
                for f in self.funcs:
                        out.write("%s %s = 0;\n"%(self.func_typedefs[f.name], f.name))

        def write_init_function(self, out):
                """Write init_<ext>(), which loads the function pointers and reports support.

                The core path is tried first (optionally via the backport
                extension), then the source-extension path.
                """
                out.write("\nExtension::SupportLevel init_%s()\n{\n"%self.ext_name.lower())
                if self.core_version:
                        out.write("\tif(is_disabled(\"GL_%s\"))\n\t\treturn Extension::UNSUPPORTED;\n"%self.ext_name)
                        out.write("#if !defined(__APPLE__) || defined(%s_%s)\n"%(self.api_prefix, self.core_version.as_define()))
                        out.write("\tif(")
                        if self.backport_ext:
                                out.write("is_supported(\"GL_%s\") || "%self.backport_ext.name)
                        out.write("is_supported(%r"%self.core_version)
                        if self.deprecated_version:
                                out.write(", %r"%self.deprecated_version)
                        out.write("))\n\t{\n")
                        for f in self.funcs:
                                supp = f.api_support.get(self.host_api.name)
                                if supp:
                                        gpa_suffix = ""
                                        # Functions core in version 1.1 or earlier use
                                        # GET_PROC_ADDRESS_1_1.  "v <= 1.1" is written as
                                        # "not v > 1.1" so only Version's __gt__ is needed.
                                        if supp.core_version is not None and not supp.core_version>Version(1, 1):
                                                gpa_suffix = "_1_1"
                                        out.write("\t\t%s = reinterpret_cast<%s>(GET_PROC_ADDRESS%s(%s));\n"%(f.name, self.func_typedefs[f.name], gpa_suffix, f.name))
                        out.write("\t\treturn Extension::CORE;\n")
                        out.write("\t}\n")
                        out.write("#endif\n")
                if self.source_exts:
                        out.write("#if !defined(__APPLE__) || defined(GL_%s)\n"%self.ext_name)
                        out.write("\tif(")
                        if self.base_version:
                                out.write("is_supported(%r) && "%self.base_version)
                        out.write("%s)\n\t{\n"%" && ".join("is_supported(\"GL_%s\")"%s.name for s in self.source_exts))
                        for f in self.funcs:
                                supp = f.api_support.get(self.host_api.name)
                                src = None
                                for e in self.source_exts:
                                        if f.name in e.things:
                                                src = f
                                        elif supp:
                                                for s in supp.sources:
                                                        if s.name in e.things:
                                                                src = s
                                                                break
                                        if src:
                                                break
                                # Not provided by any source extension but already core in
                                # base_version: load it under its own name.  (This used to
                                # assign a misspelled variable, so such functions stayed
                                # unloaded.)  base_version >= core is written as
                                # "not base < core" and is skipped when there is no base.
                                if not src and supp and supp.core_version and self.base_version and not self.base_version<supp.core_version:
                                        src = f

                                if src:
                                        out.write("\t\t%s = reinterpret_cast<%s>(GET_PROC_ADDRESS(%s));\n"%(f.name, self.func_typedefs[f.name], src.name))
                        out.write("\t\treturn Extension::EXTENSION;\n")
                        out.write("\t}\n")
                        out.write("#endif\n")
                out.write("\treturn Extension::UNSUPPORTED;\n")
                out.write("}\n")

        def write_source_outro(self, out):
                out.write("""
} // namespace GL
} // namespace Msp
""")

        def write_header(self, fn):
                """Write the generated header to the file `fn`."""
                # `with` guarantees the file is flushed and closed.
                with open(fn, "w") as out:
                        self.write_header_intro(out)
                        self.write_enum_definitions(out)
                        self.write_function_pointer_declarations(out)
                        out.write("extern Extension %s;\n"%self.ext_name)
                        self.write_header_outro(out)

        def write_source(self, fn):
                """Write the generated source to the file `fn`."""
                with open(fn, "w") as out:
                        self.write_source_intro(out)
                        self.write_function_pointer_definitions(out)
                        self.write_init_function(out)
                        out.write("\nExtension %s(\"GL_%s\", init_%s);\n"%(self.ext_name, self.ext_name, self.ext_name.lower()))
                        self.write_source_outro(out)
686
687
def dump_api_support(supp, api, indent):
        """Print supp's core/removal versions, extensions and sources.

        Each source is printed and then dumped via dump_thing_info, indented
        one level deeper.
        """
        if supp.core_version:
                print(indent+"core in version "+str(supp.core_version))
        if supp.deprecated_version:
                print(indent+"deprecated in version "+str(supp.deprecated_version))
        for ext in supp.extensions:
                print(indent+"extension %s (preference %d)"%(ext.name, ext.preference))
        deeper = indent+"  "
        for src in supp.sources:
                print(indent+"source "+src.name)
                dump_thing_info(src, api, deeper)
698
def dump_thing_info(thing, api, indent):
        """Print a thing's aliases and API support, for debugging.

        With an api name, only that API's support is shown; nothing is shown
        when the thing has no support record for it.  Without one, every API
        the thing has a record for is listed.
        """
        for a in thing.aliases:
                print(indent+"alias "+a)
        if api:
                supp = thing.api_support.get(api)
                # A thing may have no record for this API; dump_api_support
                # cannot handle None.
                if supp:
                        dump_api_support(supp, api, indent)
        else:
                for a, s in thing.api_support.items():
                        print(indent+"api "+a)
                        dump_api_support(s, a, indent+"  ")
709
710
class ExtensionParser:
	"""Reader for the extension description file (<extfile>) given to extgen.

	Each line that is neither blank nor a '#' comment has the form
	"keyword argument".  The keyword may be prefixed with "api:" to apply
	the line only when that API is the host API.  Parsed values are stored
	in the attributes set up by __init__.
	"""

	# Keywords accepted in a description file; every one takes an argument.
	KEYWORDS = ("extension", "core_version", "deprecated", "backport", "source", "ignore", "optional")

	def __init__(self, host_api):
		self.host_api = host_api
		self.target_ext = None           # name from the "extension" line
		self.core_version = None         # Version from "core_version"
		self.deprecated_version = None   # Version from "deprecated"
		self.backport_ext = None         # name from "backport"
		self.source_exts = []            # names from "source" lines
		self.ignore_things = []          # names from "ignore" lines
		self.optional_things = []        # names from "optional" lines

	def parse(self, fn):
		"""Read the description file *fn* into this parser's attributes.

		Lines restricted to an API other than host_api are skipped.  Returns
		True on success.  Prints a diagnostic and returns False on an unknown
		keyword, on a keyword without an argument, or when the file never
		names its target extension.
		"""
		with open(fn) as f:
			for line in f:
				line = line.strip()
				if not line or line.startswith("#"):
					continue

				parts = line.split()
				api = None
				keyword = parts[0]
				if ":" in keyword:
					# Split only on the first colon; anything odd after it
					# becomes an unknown keyword instead of a ValueError.
					api, keyword = keyword.split(":", 1)

				if api is not None and api!=self.host_api:
					continue

				if keyword not in self.KEYWORDS:
					print("Unknown keyword "+keyword)
					return False
				if len(parts)<2:
					print("Missing argument for keyword "+keyword)
					return False

				arg = parts[1]
				if keyword=="extension":
					self.target_ext = arg
				elif keyword=="core_version":
					self.core_version = Version(*map(int, arg.split('.')))
				elif keyword=="deprecated":
					self.deprecated_version = Version(*map(int, arg.split('.')))
				elif keyword=="backport":
					self.backport_ext = arg
				elif keyword=="source":
					self.source_exts.append(arg)
				elif keyword=="ignore":
					self.ignore_things.append(arg)
				elif keyword=="optional":
					self.optional_things.append(arg)

		# Everything downstream needs the target extension's name.
		if self.target_ext is None:
			print("No extension specified in "+fn)
			return False

		return True
756
757
def get_extension(api_map, ext_name):
	"""Look up an extension object by name.

	*ext_name* is either a bare extension name, which is looked up in the
	"gl" API, or "<api>.<name>" to pick the API explicitly.  Missing APIs or
	extensions raise KeyError.
	"""
	ext_api_name = "gl"
	if "." in ext_name:
		ext_api_name, ext_name = ext_name.split(".")
	api = api_map[ext_api_name]
	return api.extensions[ext_name]
765
def resolve_things(api, things):
	"""Replace each thing by the core things of *api* that it aliases.

	For each thing, the core things of *api* named by its aliases are
	collected in alias order.  When there is at least one, they take the
	thing's place in the result; otherwise the thing itself is kept.

	The matches are built as a real list.  filter()/map() return lazy,
	always-truthy iterators on Python 3, which would make every thing
	without a core alias vanish.
	"""
	rthings = []
	for t in things:
		core = [c for c in (api.core_things.get(a) for a in t.aliases) if c]
		if core:
			rthings.extend(core)
		else:
			rthings.append(t)

	return rthings
776
def collect_extension_things(host_api, target_ext, ignore):
	"""Return the things of *target_ext* whose names are not in *ignore*.

	The result is passed through resolve_things for the extension's own API.
	*host_api* is accepted for the caller's convenience but not used here.
	"""
	selected = []
	for name, thing in target_ext.things.items():
		if name in ignore:
			continue
		selected.append(thing)
	return resolve_things(target_ext.api, selected)
780
def collect_optional_things(target_ext, names):
	"""Look up the things listed in *names* and resolve them.

	Each name is looked up among the target extension's own things first,
	falling back to the core things of the extension's API.  A name found in
	neither raises KeyError.  The result goes through resolve_things.
	"""
	things = [target_ext.things[name] if name in target_ext.things
			else target_ext.api.core_things[name]
			for name in names]
	return resolve_things(target_ext.api, things)
789
def main():
	"""Command-line entry point of extgen.

	Parses the arguments "[-g] [api] <extfile> [<outfile>]", reads the GL
	registry files and writes the generated header and source.  Prints the
	usage text and exits with status 1 when <extfile> is missing.  Exits
	with status 1 when the extension file fails to parse.
	"""
	usage = """Usage:
  extgen.py [-g] [api] <extfile> [<outfile>]

Reads gl.xml and generates C++ source files to use an OpenGL extension
described in <extfile>.  If <outfile> is absent, the extension's lowercased
name is used.  Anything after the last dot in <outfile> is removed and
replaced with cpp and h.  With -g, information about the things included in
the extension is printed for debugging."""

	i = 1

	debug = False
	if i<len(sys.argv) and sys.argv[i]=="-g":
		debug = True
		i += 1

	host_api_name = "gl"
	if i<len(sys.argv) and sys.argv[i].startswith("gl"):
		host_api_name = sys.argv[i]
		i += 1

	# <extfile> is mandatory.  Checking after the options means "-g" or an
	# API name on its own shows the usage instead of raising IndexError.
	if i>=len(sys.argv):
		print(usage)
		sys.exit(1)

	ext_parser = ExtensionParser(host_api_name)
	if not ext_parser.parse(sys.argv[i]):
		sys.exit(1)
	i += 1

	if i<len(sys.argv):
		out_base = os.path.splitext(sys.argv[i])[0]
	else:
		out_base = ext_parser.target_ext.lower()

	xml_parser = GlXmlParser()
	xml_parser.parse_file("gl.xml")
	xml_parser.parse_file("gl.fixes.xml")
	xml_parser.parse_file("gl.msp.xml")
	xml_parser.finalize()

	host_api = xml_parser.apis[host_api_name]
	target_ext = get_extension(xml_parser.apis, ext_parser.target_ext)
	things = collect_extension_things(host_api, target_ext, ext_parser.ignore_things+ext_parser.optional_things)
	optional_things = collect_optional_things(target_ext, ext_parser.optional_things)

	if debug:
		print("--- Things included in this extension ---")
		all_things = sorted(things+optional_things, key=(lambda t: t.name))
		for t in all_things:
			print(t.name)
			if t in optional_things:
				print("  optional")
			dump_thing_info(t, None, "  ")

	generator = SourceGenerator(host_api, target_ext.name, things, optional_things, debug)
	if ext_parser.core_version:
		generator.core_version = ext_parser.core_version
	if ext_parser.deprecated_version:
		generator.deprecated_version = ext_parser.deprecated_version
	if ext_parser.backport_ext:
		if ext_parser.backport_ext=="none":
			generator.backport_ext = None
		else:
			generator.backport_ext = get_extension(xml_parser.apis, ext_parser.backport_ext)
	if ext_parser.source_exts:
		generator.base_version = None
		if len(ext_parser.source_exts)==1 and ext_parser.source_exts[0]=="none":
			generator.source_exts = []
		else:
			generator.source_exts = [get_extension(xml_parser.apis, e) for e in ext_parser.source_exts]
	if debug:
		generator.dump_info()
	generator.write_header(out_base+".h")
	generator.write_source(out_base+".cpp")

if __name__=="__main__":
	main()