]> git.tdb.fi Git - libs/gl.git/blobdiff - scripts/extgen.py
Make extensions compatible with OpenGL ES
[libs/gl.git] / scripts / extgen.py
index 4c3565e82cc0ac4ded23b39fbef6028d1978c27d..ff40ec60eadd1f515336f78d4d82eec2e0cfdfff 100755 (executable)
@@ -324,7 +324,7 @@ class GlXmlParser:
                self.sort_extensions()
 
 
-def detect_core_version(host_api, things):
+def detect_core_version(host_api, things, debug=None):
        max_version = Version(1, 0)
        max_count = 0
        lower_count = 0
@@ -347,11 +347,16 @@ def detect_core_version(host_api, things):
                print "Warning: Inconsistent core version %s"%max_version
 
        if missing:
+               if debug:
+                       print "---"
+                       print "%d things missing from core:"%len(missing)
+                       for t in missing:
+                               print "  "+t.name
                return None
 
        return max_version
 
-def detect_deprecated_version(host_api, things):
+def detect_deprecated_version(host_api, things, debug):
        min_version = None
        deprecated = []
        for t in things:
@@ -365,6 +370,11 @@ def detect_deprecated_version(host_api, things):
 
        if min_version and len(deprecated)*2<len(things):
                print "Warning: Inconsistent deprecation version %s"%min_version
+               if debug:
+                       print "---"
+                       print "%d things are deprecated:"%len(deprecated)
+                       for t in deprecated:
+                               print "  "+t.name
 
        return min_version
 
@@ -404,7 +414,7 @@ def collect_extensions(thing, api, exts):
        for s in supp.sources:
                collect_extensions(s, api, exts)
 
-def detect_source_extension(host_api, things):
+def detect_source_extension(host_api, things, debug=False):
        things_by_ext = {}
        for t in things:
                exts = []
@@ -412,9 +422,37 @@ def detect_source_extension(host_api, things):
                for e in exts:
                        things_by_ext.setdefault(e, []).append(t)
 
+       if debug:
+               print "---"
+               print "Looking for %d things in %d extensions"%(len(things), len(things_by_ext))
+
        extensions = []
+       keep_exts = 0
+       base_version = None
+       recheck_base_version = True
        missing = set(things)
-       while missing and things_by_ext:
+       while 1:
+               if recheck_base_version:
+                       max_version = Version(1, 0)
+                       for t in missing:
+                               supp = t.api_support.get(host_api.name)
+                               if supp and supp.core_version and max_version:
+                                       max_version = max(max_version, supp.core_version)
+                               else:
+                                       max_version = None
+
+                       if max_version:
+                               if not base_version or max_version<base_version:
+                                       base_version = max_version
+                                       keep_exts = len(extensions)
+                       elif not base_version:
+                               keep_exts = len(extensions)
+
+                       recheck_base_version = False
+
+               if not missing or not things_by_ext:
+                       break
+
                largest_ext = None
                largest_count = 0
                for e, t in things_by_ext.iteritems():
@@ -425,11 +463,16 @@ def detect_source_extension(host_api, things):
                        elif count==largest_count and e.preference>largest_ext.preference:
                                largest_ext = e
 
+               if debug:
+                       print "Found %d things in %s"%(largest_count, largest_ext.name)
+
                extensions.append(largest_ext)
                for t in things_by_ext[largest_ext]:
                        missing.remove(t)
-               if not missing:
-                       break
+
+                       supp = t.api_support.get(host_api.name)
+                       if supp and supp.core_version==base_version:
+                               recheck_base_version = True
 
                del things_by_ext[largest_ext]
                for e in things_by_ext.keys():
@@ -439,14 +482,25 @@ def detect_source_extension(host_api, things):
                        else:
                                del things_by_ext[e]
 
-       if missing:
-               return None
-
-       return extensions
+       if not missing:
+               return None, extensions
+       elif base_version:
+               if debug:
+                       print "Found remaining things in version %s"%base_version
+                       if keep_exts<len(extensions):
+                               print "Discarding %d extensions that do not improve base version"%(len(extensions)-keep_exts)
+               del extensions[keep_exts:]
+               return base_version, extensions
+       else:
+               if debug:
+                       print "%d things still missing:"%len(missing)
+                       for t in missing:
+                               print "  "+t.name
+               return None, None
 
 
 class SourceGenerator:
-       def __init__(self, host_api, ext_name, things, optional_things):
+       def __init__(self, host_api, ext_name, things, optional_things, debug=False):
                self.host_api = host_api
                self.api_prefix = "GL"
                if self.host_api.name=="gles2":
@@ -458,14 +512,29 @@ class SourceGenerator:
                self.func_typedefs = dict((f.name, "FPtr_"+f.name) for f in self.funcs)
                self.enums = filter((lambda t: t.kind==Thing.ENUM), all_things)
                self.enums.sort(key=(lambda e: e.value))
-               self.core_version = detect_core_version(host_api, things)
-               self.deprecated_version = detect_deprecated_version(host_api, things)
+               self.core_version = detect_core_version(host_api, things, debug)
+               self.deprecated_version = detect_deprecated_version(host_api, things, debug)
                self.backport_ext = detect_backport_extension(host_api, things);
-               self.source_exts = detect_source_extension(host_api, things)
+               b, e = detect_source_extension(host_api, things, debug)
+               self.base_version = b
+               self.source_exts = e
 
                if not self.core_version and not self.backport_ext and not self.source_exts:
                        print "Warning: Not supportable on host API"
 
+       def dump_info(self):
+               print "--- Extension information ---"
+               print "Extension %s"%self.ext_name
+               print "Core %s"%self.core_version
+               print "Deprecated %s"%self.deprecated_version
+               if self.backport_ext:
+                       print "Backport %s"%self.backport_ext.name
+               if self.source_exts:
+                       names = [e.name for e in self.source_exts]
+                       if self.base_version:
+                               names.insert(0, "Version %s"%self.base_version)
+                       print "Sources %s"%", ".join(names)
+
        def write_header_intro(self, out):
                out.write("#ifndef MSP_GL_%s_\n"%self.ext_name.upper())
                out.write("#define MSP_GL_%s_\n"%self.ext_name.upper())
@@ -554,30 +623,36 @@ namespace GL {
                                out.write(", %r"%self.deprecated_version)
                        out.write("))\n\t{\n")
                        for f in self.funcs:
-                               supp = f.api_support[self.host_api.name]
-                               gpa_suffix = ""
-                               if supp.core_version is not None and supp.core_version<=Version(1, 1):
-                                       gpa_suffix = "_1_1"
-                               out.write("\t\t%s = reinterpret_cast<%s>(GET_PROC_ADDRESS%s(%s));\n"%(f.name, self.func_typedefs[f.name], gpa_suffix, f.name))
+                               supp = f.api_support.get(self.host_api.name)
+                               if supp:
+                                       gpa_suffix = ""
+                                       if supp.core_version is not None and supp.core_version<=Version(1, 1):
+                                               gpa_suffix = "_1_1"
+                                       out.write("\t\t%s = reinterpret_cast<%s>(GET_PROC_ADDRESS%s(%s));\n"%(f.name, self.func_typedefs[f.name], gpa_suffix, f.name))
                        out.write("\t\treturn Extension::CORE;\n")
                        out.write("\t}\n")
                        out.write("#endif\n")
                if self.source_exts:
                        out.write("#if !defined(__APPLE__) || defined(GL_%s)\n"%self.ext_name)
-                       out.write("\tif(%s)\n\t{\n"%" && ".join("is_supported(\"GL_%s\")"%s.name for s in self.source_exts))
+                       out.write("\tif(")
+                       if self.base_version:
+                               out.write("is_supported(%r) && "%self.base_version)
+                       out.write("%s)\n\t{\n"%" && ".join("is_supported(\"GL_%s\")"%s.name for s in self.source_exts))
                        for f in self.funcs:
-                               supp = f.api_support[self.host_api.name]
-                               if supp.sources:
-                                       src = None
-                                       for e in self.source_exts:
+                               supp = f.api_support.get(self.host_api.name)
+                               src = None
+                               for e in self.source_exts:
+                                       if f in e.things:
+                                               src = f
+                                       elif supp:
                                                for s in supp.sources:
                                                        if s.name in e.things:
                                                                src = s
                                                                break
-                                               if src:
-                                                       break
-                               else:
-                                       src = f
+                                       if src:
+                                               break
+                               if not src and supp and supp.core_version and self.base_version>=supp.core_version:
+                                       sec = f
 
                                if src:
                                        out.write("\t\t%s = reinterpret_cast<%s>(GET_PROC_ADDRESS(%s));\n"%(f.name, self.func_typedefs[f.name], src.name))
@@ -610,6 +685,29 @@ namespace GL {
                self.write_source_outro(out)
 
 
+def dump_api_support(supp, api, indent):
+       if supp.core_version:
+               print indent+"core in version "+str(supp.core_version)
+       if supp.deprecated_version:
+               print indent+"deprecated in version "+str(supp.deprecated_version)
+       for e in supp.extensions:
+               print indent+"extension %s (preference %d)"%(e.name, e.preference)
+       for r in supp.sources:
+               print indent+"source "+r.name
+               dump_thing_info(r, api, indent+"  ")
+
+def dump_thing_info(thing, api, indent):
+       for a in thing.aliases:
+               print indent+"alias "+a
+       if api:
+               supp = thing.api_support.get(api)
+               dump_api_support(supp, api, indent)
+       else:
+               for a, s in thing.api_support.iteritems():
+                       print indent+"api "+a
+                       dump_api_support(s, a, indent+"  ")
+
+
 class ExtensionParser:
        def __init__(self, host_api):
                self.host_api = host_api
@@ -617,13 +715,14 @@ class ExtensionParser:
                self.core_version = None
                self.deprecated_version = None
                self.backport_ext = None
+               self.source_exts = []
                self.ignore_things = []
                self.optional_things = []
 
        def parse(self, fn):
                for line in open(fn):
                        line = line.strip()
-                       if line.startswith("#"):
+                       if not line or line.startswith("#"):
                                continue
 
                        parts = line.split()
@@ -643,6 +742,8 @@ class ExtensionParser:
                                self.deprecated_version = Version(*map(int, parts[1].split('.')))
                        elif keyword=="backport":
                                self.backport_ext = parts[1]
+                       elif keyword=="source":
+                               self.source_exts.append(parts[1])
                        elif keyword=="ignore":
                                self.ignore_things.append(parts[1])
                        elif keyword=="optional":
@@ -697,9 +798,14 @@ name is used.  Anything after the last dot in <outfile> is removed and
 replaced with cpp and h."""
                sys.exit(1)
 
-       host_api_name = "gl"
-
        i = 1
+
+       debug = False
+       if sys.argv[i]=="-g":
+               debug = True
+               i += 1
+
+       host_api_name = "gl"
        if sys.argv[i].startswith("gl"):
                host_api_name = sys.argv[i]
                i += 1
@@ -725,7 +831,17 @@ replaced with cpp and h."""
        things = collect_extension_things(host_api, target_ext, ext_parser.ignore_things+ext_parser.optional_things)
        optional_things = collect_optional_things(target_ext, ext_parser.optional_things)
 
-       generator = SourceGenerator(host_api, target_ext.name, things, optional_things)
+       if debug:
+               print "--- Things included in this extension ---"
+               all_things = things+optional_things
+               all_things.sort(key=(lambda t: t.name))
+               for t in all_things:
+                       print t.name
+                       if t in optional_things:
+                               print "  optional"
+                       dump_thing_info(t, None, "  ")
+
+       generator = SourceGenerator(host_api, target_ext.name, things, optional_things, debug)
        if ext_parser.core_version:
                generator.core_version = ext_parser.core_version
        if ext_parser.deprecated_version:
@@ -735,6 +851,14 @@ replaced with cpp and h."""
                        generator.backport_ext = None
                else:
                        generator.backport_ext = get_extension(xml_parser.apis, ext_parser.backport_ext)
+       if ext_parser.source_exts:
+               generator.base_version = None
+               if len(ext_parser.source_exts)==1 and ext_parser.source_exts[0]=="none":
+                       generator.source_exts = []
+               else:
+                       generator.source_exts = map((lambda e: get_extension(xml_parser.apis, e)), ext_parser.source_exts)
+       if debug:
+               generator.dump_info()
        generator.write_header(out_base+".h")
        generator.write_source(out_base+".cpp")