]> git.tdb.fi Git - libs/gl.git/commitdiff
Rewrite syntax tree modifications
authorMikko Rasa <tdb@tdb.fi>
Sun, 21 Feb 2021 02:20:32 +0000 (04:20 +0200)
committerMikko Rasa <tdb@tdb.fi>
Sun, 21 Feb 2021 03:02:28 +0000 (05:02 +0200)
BlockModifier left dangling pointers in various variable maps, which
started causing trouble when I wanted to print out the entire AST.
NodeRemover is now used for all removal operations and it properly
handles the maps.

source/glsl/compatibility.cpp
source/glsl/compatibility.h
source/glsl/generate.cpp
source/glsl/generate.h
source/glsl/optimize.cpp
source/glsl/optimize.h
source/glsl/visitor.cpp
source/glsl/visitor.h

index 692db972b56491eca5de797ae0f79bec61995754..a9f38a35cda42481935864b5bbd97285bbf63462 100644 (file)
@@ -14,25 +14,23 @@ namespace GL {
 namespace SL {
 
 DefaultPrecisionGenerator::DefaultPrecisionGenerator():
-       stage_type(Stage::SHARED),
-       toplevel(true)
+       stage(0)
 { }
 
-void DefaultPrecisionGenerator::apply(Stage &stage)
+void DefaultPrecisionGenerator::apply(Stage &s)
 {
-       stage_type = stage.type;
-       visit(stage.content);
+       stage = &s;
+       visit(s.content);
 }
 
 void DefaultPrecisionGenerator::visit(Block &block)
 {
-       if(toplevel)
+       for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
        {
-               SetForScope<bool> set(toplevel, false);
-               BlockModifier::visit(block);
+               if(&block==&stage->content)
+                       insert_point = i;
+               (*i)->visit(*this);
        }
-       else
-               TraversingVisitor::visit(block);
 }
 
 void DefaultPrecisionGenerator::visit(Precision &prec)
@@ -56,21 +54,27 @@ void DefaultPrecisionGenerator::visit(VariableDeclaration &var)
                Precision *prec = new Precision;
                if(!type.compare(0, 7, "sampler"))
                        prec->precision = "lowp";
-               else if(stage_type==Stage::FRAGMENT)
+               else if(stage->type==Stage::FRAGMENT)
                        prec->precision = "mediump";
                else
                        prec->precision = "highp";
                prec->type = type;
-               insert_nodes.push_back(prec);
+               stage->content.body.insert(insert_point, prec);
 
                have_default.insert(type);
        }
 }
 
 
-void PrecisionRemover::visit(Precision &)
+void PrecisionRemover::apply(Stage &stage)
+{
+       visit(stage.content);
+       NodeRemover().apply(stage, nodes_to_remove);
+}
+
+void PrecisionRemover::visit(Precision &prec)
 {
-       remove_node = true;
+       nodes_to_remove.insert(&prec);
 }
 
 void PrecisionRemover::visit(VariableDeclaration &var)
@@ -91,6 +95,16 @@ void LegacyConverter::apply(Stage &s)
        visit(s.content);
 }
 
+void LegacyConverter::visit(Block &block)
+{
+       for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
+       {
+               if(&block==&stage->content)
+                       uniform_insert_point = i;
+               (*i)->visit(*this);
+       }
+}
+
 bool LegacyConverter::check_version(const Version &feature_version) const
 {
        if(target_version<feature_version)
@@ -279,7 +293,7 @@ void LegacyConverter::visit(VariableDeclaration &var)
                if(stage->type==Stage::FRAGMENT && var.interface=="out")
                {
                        frag_out = &var;
-                       remove_node = true;
+                       nodes_to_remove.insert(&var);
                }
        }
 
@@ -306,7 +320,10 @@ bool LegacyConverter::supports_interface_blocks(const string &iface) const
 void LegacyConverter::visit(InterfaceBlock &iface)
 {
        if(!supports_interface_blocks(iface.interface))
-               flatten_block(iface.members);
+       {
+               stage->content.body.splice(uniform_insert_point, iface.members.body);
+               nodes_to_remove.insert(&iface);
+       }
 }
 
 } // namespace SL
index 59e21c35a3644a78451f85b1100bc47b107a9547..a391aab0f20f7dbcb8bd0a4477434570437e7697 100644 (file)
@@ -8,12 +8,12 @@ namespace Msp {
 namespace GL {
 namespace SL {
 
-class DefaultPrecisionGenerator: private BlockModifier
+class DefaultPrecisionGenerator: private TraversingVisitor
 {
 private:
-       Stage::Type stage_type;
-       bool toplevel;
+       Stage *stage;
        std::set<std::string> have_default;
+       NodeList<Statement>::iterator insert_point;
 
 public:
        DefaultPrecisionGenerator();
@@ -24,21 +24,24 @@ private:
        virtual void visit(Block &);
        virtual void visit(Precision &);
        virtual void visit(VariableDeclaration &);
-       using BlockModifier::visit;
+       using TraversingVisitor::visit;
 };
 
-class PrecisionRemover: private BlockModifier
+class PrecisionRemover: private TraversingVisitor
 {
+private:
+       std::set<Node *> nodes_to_remove;
+
 public:
-       void apply(Stage &s) { visit(s.content); }
+       void apply(Stage &);
 
 private:
        virtual void visit(Precision &);
        virtual void visit(VariableDeclaration &);
-       using BlockModifier::visit;
+       using TraversingVisitor::visit;
 };
 
-class LegacyConverter: private BlockModifier
+class LegacyConverter: private TraversingVisitor
 {
 private:
        Stage *stage;
@@ -46,6 +49,8 @@ private:
        Version target_version;
        std::string type;
        VariableDeclaration *frag_out;
+       NodeList<Statement>::iterator uniform_insert_point;
+       std::set<Node *> nodes_to_remove;
 
 public:
        LegacyConverter();
@@ -53,6 +58,7 @@ public:
        virtual void apply(Stage &);
 
 private:
+       virtual void visit(Block &);
        bool check_version(const Version &) const;
        bool check_extension(const Extension &) const;
        bool supports_unified_interface_syntax() const;
@@ -66,7 +72,7 @@ private:
        virtual void visit(VariableDeclaration &);
        bool supports_interface_blocks(const std::string &) const;
        virtual void visit(InterfaceBlock &);
-       using BlockModifier::visit;
+       using TraversingVisitor::visit;
 };
 
 } // namespace SL
index 379a6871ae754aef88ec7b589997493e025733e1..ac16c31f889dce86ae02541121802d84b59a496e 100644 (file)
@@ -12,13 +12,19 @@ DeclarationCombiner::DeclarationCombiner():
        toplevel(true)
 { }
 
+void DeclarationCombiner::apply(Stage &stage)
+{
+       visit(stage.content);
+       NodeRemover().apply(stage, nodes_to_remove);
+}
+
 void DeclarationCombiner::visit(Block &block)
 {
        if(!toplevel)
                return;
 
        SetForScope<bool> set(toplevel, false);
-       BlockModifier::visit(block);
+       TraversingVisitor::visit(block);
 }
 
 void DeclarationCombiner::visit(FunctionDeclaration &func)
@@ -64,7 +70,7 @@ void DeclarationCombiner::visit(VariableDeclaration &var)
                        else
                                ptr->layout = var.layout;
                }
-               remove_node = true;
+               nodes_to_remove.insert(&var);
        }
        else
                ptr = &var;
@@ -240,7 +246,8 @@ void FunctionResolver::visit(FunctionDeclaration &func)
 
 InterfaceGenerator::InterfaceGenerator():
        stage(0),
-       scope_level(0)
+       scope_level(0),
+       current_block(0)
 { }
 
 string InterfaceGenerator::get_out_prefix(Stage::Type type)
@@ -260,26 +267,20 @@ void InterfaceGenerator::apply(Stage &s)
                in_prefix = get_out_prefix(stage->previous->type);
        out_prefix = get_out_prefix(stage->type);
        visit(s.content);
+       NodeRemover().apply(s, nodes_to_remove);
 }
 
 void InterfaceGenerator::visit(Block &block)
 {
        SetForScope<unsigned> set(scope_level, scope_level+1);
-       for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); )
+       SetForScope<Block *> set_block(current_block, &block);
+       for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
        {
-               (*i)->visit(*this);
-
+               assignment_insert_point = i;
                if(scope_level==1)
-               {
-                       for(map<string, RefPtr<VariableDeclaration> >::iterator j=iface_declarations.begin(); j!=iface_declarations.end(); ++j)
-                       {
-                               NodeList<Statement>::iterator k = block.body.insert(i, j->second);
-                               (*k)->visit(*this);
-                       }
-                       iface_declarations.clear();
-               }
+                       iface_insert_point = i;
 
-               apply_and_increment(block, i);
+               (*i)->visit(*this);
        }
 }
 
@@ -292,7 +293,7 @@ string InterfaceGenerator::change_prefix(const string &name, const string &prefi
 bool InterfaceGenerator::generate_interface(VariableDeclaration &var, const string &iface, const string &name)
 {
        const map<string, VariableDeclaration *> &stage_vars = (iface=="in" ? stage->in_variables : stage->out_variables);
-       if(stage_vars.count(name) || iface_declarations.count(name))
+       if(stage_vars.count(name))
                return false;
 
        VariableDeclaration* iface_var = new VariableDeclaration;
@@ -309,7 +310,12 @@ bool InterfaceGenerator::generate_interface(VariableDeclaration &var, const stri
                iface_var->array_size = var.array_size;
        if(iface=="in")
                iface_var->linked_declaration = &var;
-       iface_declarations[name] = iface_var;
+       stage->content.body.insert(iface_insert_point, iface_var);
+       {
+               SetForScope<unsigned> set_level(scope_level, 1);
+               SetForScope<Block *> set_block(current_block, &stage->content);
+               iface_var->visit(*this);
+       }
 
        return true;
 }
@@ -325,8 +331,8 @@ ExpressionStatement &InterfaceGenerator::insert_assignment(const string &left, E
 
        ExpressionStatement *stmt = new ExpressionStatement;
        stmt->expression = assign;
+       current_block->body.insert(assignment_insert_point, stmt);
        stmt->visit(*this);
-       insert_nodes.push_back(stmt);
 
        return *stmt;
 }
@@ -335,8 +341,6 @@ void InterfaceGenerator::visit(VariableReference &var)
 {
        if(var.declaration || !stage->previous)
                return;
-       if(iface_declarations.count(var.name))
-               return;
 
        const map<string, VariableDeclaration *> &prev_out = stage->previous->out_variables;
        map<string, VariableDeclaration *>::const_iterator i = prev_out.find(var.name);
@@ -357,7 +361,7 @@ void InterfaceGenerator::visit(VariableDeclaration &var)
                        stage->out_variables[var.name] = &var;
                else if(generate_interface(var, "out", change_prefix(var.name, string())))
                {
-                       remove_node = true;
+                       nodes_to_remove.insert(&var);
                        if(var.init_expression)
                        {
                                ExpressionStatement &stmt = insert_assignment(var.name, var.init_expression->clone());
@@ -393,9 +397,6 @@ void InterfaceGenerator::visit(Passthrough &pass)
 
        for(map<string, VariableDeclaration *>::const_iterator i=stage->in_variables.begin(); i!=stage->in_variables.end(); ++i)
                pass_vars.push_back(i->second);
-       for(map<string, RefPtr<VariableDeclaration> >::const_iterator i=iface_declarations.begin(); i!=iface_declarations.end(); ++i)
-               if(i->second->interface=="in")
-                       pass_vars.push_back(i->second.get());
 
        if(stage->previous)
        {
@@ -449,7 +450,7 @@ void InterfaceGenerator::visit(Passthrough &pass)
                        insert_assignment(out_name, ref);
        }
 
-       remove_node = true;
+       nodes_to_remove.insert(&pass);
 }
 
 
index 0c2a0ad8fce4b03aee235202f9a4eaebe0937d59..ce95754eb4d68f9d28addbbceaa06839988be507 100644 (file)
@@ -11,23 +11,24 @@ namespace Msp {
 namespace GL {
 namespace SL {
 
-class DeclarationCombiner: private BlockModifier
+class DeclarationCombiner: private TraversingVisitor
 {
 private:
        bool toplevel;
        std::map<std::string, std::vector<FunctionDeclaration *> > functions;
        std::map<std::string, VariableDeclaration *> variables;
+       std::set<Node *> nodes_to_remove;
 
 public:
        DeclarationCombiner();
 
-       void apply(Stage &s) { visit(s.content); }
+       void apply(Stage &);
 
 private:
        virtual void visit(Block &);
        virtual void visit(FunctionDeclaration &);
        virtual void visit(VariableDeclaration &);
-       using BlockModifier::visit;
+       using TraversingVisitor::visit;
 };
 
 class VariableResolver: private TraversingVisitor
@@ -72,14 +73,17 @@ private:
        using TraversingVisitor::visit;
 };
 
-class InterfaceGenerator: private BlockModifier
+class InterfaceGenerator: private TraversingVisitor
 {
 private:
        Stage *stage;
        std::string in_prefix;
        std::string out_prefix;
        unsigned scope_level;
-       std::map<std::string, RefPtr<VariableDeclaration> > iface_declarations;
+       Block *current_block;
+       NodeList<Statement>::iterator iface_insert_point;
+       NodeList<Statement>::iterator assignment_insert_point;
+       std::set<Node *> nodes_to_remove;
 
 public:
        InterfaceGenerator();
@@ -95,7 +99,7 @@ private:
        virtual void visit(VariableReference &);
        virtual void visit(VariableDeclaration &);
        virtual void visit(Passthrough &);
-       using BlockModifier::visit;
+       using TraversingVisitor::visit;
 };
 
 class DeclarationReorderer: private TraversingVisitor
index 7bbe14da382b6dbd54fb23ce113187f7e8616de3..208e17a39b49e6fe1fd92f10b31e65808e77068b 100644 (file)
@@ -125,13 +125,25 @@ void FunctionInliner::visit(Return &ret)
 
 ConstantConditionEliminator::ConstantConditionEliminator():
        scope_level(0),
+       current_block(0),
        record_only(false)
 { }
 
+void ConstantConditionEliminator::apply(Stage &stage)
+{
+       visit(stage.content);
+       NodeRemover().apply(stage, nodes_to_remove);
+}
+
 void ConstantConditionEliminator::visit(Block &block)
 {
        SetForScope<unsigned> set(scope_level, scope_level+1);
-       BlockModifier::visit(block);
+       SetForScope<Block *> set_block(current_block, &block);
+       for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); ++i)
+       {
+               insert_point = i;
+               (*i)->visit(*this);
+       }
 
        for(map<string, VariableDeclaration *>::const_iterator i=block.variables.begin(); i!=block.variables.end(); ++i)
                variable_values.erase(i->second);
@@ -163,7 +175,9 @@ void ConstantConditionEliminator::visit(Conditional &cond)
                cond.condition->visit(eval);
                if(eval.is_result_valid())
                {
-                       flatten_block(eval.get_result() ? cond.body : cond.else_body);
+                       Block &block = (eval.get_result() ? cond.body : cond.else_body);
+                       current_block->body.splice(insert_point, block.body);
+                       nodes_to_remove.insert(&cond);
                        return;
                }
        }
@@ -185,7 +199,7 @@ void ConstantConditionEliminator::visit(Iteration &iter)
                        iter.condition->visit(eval);
                        if(eval.is_result_valid() && !eval.get_result())
                        {
-                               remove_node = true;
+                               nodes_to_remove.insert(&iter);
                                return;
                        }
                }
index 2d1b20c99e72d7b1361508cc9bdda14c2424952b..ea772b3d841b1ef78171d99ebb9afef064a76dfd 100644 (file)
@@ -53,17 +53,20 @@ private:
        using TraversingVisitor::visit;
 };
 
-class ConstantConditionEliminator: private BlockModifier
+class ConstantConditionEliminator: private TraversingVisitor
 {
 private:
        unsigned scope_level;
+       Block *current_block;
        bool record_only;
        ExpressionEvaluator::ValueMap variable_values;
+       NodeList<Statement>::iterator insert_point;
+       std::set<Node *> nodes_to_remove;
 
 public:
        ConstantConditionEliminator();
 
-       void apply(Stage &s) { visit(s.content); }
+       void apply(Stage &);
 
 private:
        virtual void visit(Block &);
@@ -72,7 +75,7 @@ private:
        virtual void visit(VariableDeclaration &);
        virtual void visit(Conditional &);
        virtual void visit(Iteration &);
-       using BlockModifier::visit;
+       using TraversingVisitor::visit;
 };
 
 class UnusedVariableRemover: private TraversingVisitor
index 98777ae9b0541f97af5b4156a7725235d9cb24d6..cf0cca9612e92132e9a89f16471f788c8f7d00a1 100644 (file)
@@ -114,41 +114,11 @@ void TraversingVisitor::visit(Return &ret)
 }
 
 
-BlockModifier::BlockModifier():
-       remove_node(false)
-{ }
-
-void BlockModifier::flatten_block(Block &block)
-{
-       insert_nodes.insert(insert_nodes.end(), block.body.begin(), block.body.end());
-       remove_node = true;
-}
-
-void BlockModifier::apply_and_increment(Block &block, NodeList<Statement>::iterator &i)
-{
-       block.body.insert(i, insert_nodes.begin(), insert_nodes.end());
-       insert_nodes.clear();
-
-       if(remove_node)
-               block.body.erase(i++);
-       else
-               ++i;
-       remove_node = false;
-}
-
-void BlockModifier::visit(Block &block)
-{
-       for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); )
-       {
-               (*i)->visit(*this);
-               apply_and_increment(block, i);
-       }
-}
-
-
 NodeRemover::NodeRemover():
        stage(0),
-       to_remove(0)
+       to_remove(0),
+       anonymous(false),
+       recursive_remove(false)
 { }
 
 void NodeRemover::apply(Stage &s, const set<Node *> &tr)
@@ -158,8 +128,16 @@ void NodeRemover::apply(Stage &s, const set<Node *> &tr)
        visit(s.content);
 }
 
+void NodeRemover::remove_variable(map<string, VariableDeclaration *> &vars, VariableDeclaration &decl)
+{
+       map<string, VariableDeclaration *>::iterator i = vars.find(decl.name);
+       if(i!=vars.end() && i->second==&decl)
+               vars.erase(i);
+}
+
 void NodeRemover::visit(Block &block)
 {
+       blocks.push_back(&block);
        for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); )
        {
                (*i)->visit(*this);
@@ -168,14 +146,24 @@ void NodeRemover::visit(Block &block)
                else
                        ++i;
        }
+       blocks.pop_back();
+}
+
+void NodeRemover::visit(StructDeclaration &strct)
+{
+       if(to_remove->count(&strct))
+               blocks.back()->types.erase(strct.name);
 }
 
 void NodeRemover::visit(VariableDeclaration &var)
 {
-       if(to_remove->count(&var))
+       if(recursive_remove || to_remove->count(&var))
        {
-               stage->in_variables.erase(var.name);
-               stage->out_variables.erase(var.name);
+               remove_variable(blocks.back()->variables, var);
+               if(anonymous && blocks.size()>1)
+                       remove_variable(blocks[blocks.size()-2]->variables, var);
+               remove_variable(stage->in_variables, var);
+               remove_variable(stage->out_variables, var);
                stage->locations.erase(var.name);
                if(var.linked_declaration)
                        var.linked_declaration->linked_declaration = 0;
@@ -184,6 +172,13 @@ void NodeRemover::visit(VariableDeclaration &var)
                var.init_expression = 0;
 }
 
+void NodeRemover::visit(InterfaceBlock &iface)
+{
+       SetFlag set_anon(anonymous);
+       SetFlag set_recursive(recursive_remove, recursive_remove || to_remove->count(&iface));
+       TraversingVisitor::visit(iface);
+}
+
 void NodeRemover::visit(Iteration &iter)
 {
        if(to_remove->count(iter.init_statement.get()))
index 8bfcac0fd490b85bb2db1c32dfca371a4e786363..9ca910becaa38656c15382ca3f8be0a97f499ba4 100644 (file)
@@ -66,22 +66,6 @@ public:
        virtual void visit(Return &);
 };
 
-class BlockModifier: public TraversingVisitor
-{
-protected:
-       bool remove_node;
-       std::vector<RefPtr<Statement> > insert_nodes;
-
-       BlockModifier();
-
-       void flatten_block(Block &);
-       void apply_and_increment(Block &, NodeList<Statement>::iterator &);
-
-public:
-       using TraversingVisitor::visit;
-       virtual void visit(Block &);
-};
-
 template<typename T>
 class NodeGatherer: private TraversingVisitor
 {
@@ -101,6 +85,9 @@ class NodeRemover: private TraversingVisitor
 private:
        Stage *stage;
        const std::set<Node *> *to_remove;
+       std::vector<Block *> blocks;
+       bool anonymous;
+       bool recursive_remove;
 
 public:
        NodeRemover();
@@ -108,9 +95,13 @@ public:
        void apply(Stage &, const std::set<Node *> &);
 
 private:
+       void remove_variable(std::map<std::string, VariableDeclaration *> &, VariableDeclaration &);
+
        using TraversingVisitor::visit;
        virtual void visit(Block &);
+       virtual void visit(StructDeclaration &);
        virtual void visit(VariableDeclaration &);
+       virtual void visit(InterfaceBlock &);
        virtual void visit(Iteration &);
 };