]> git.tdb.fi Git - libs/gl.git/commitdiff
Refactor the way of applying visitors to stages
authorMikko Rasa <tdb@tdb.fi>
Mon, 15 Feb 2021 14:01:55 +0000 (16:01 +0200)
committerMikko Rasa <tdb@tdb.fi>
Mon, 15 Feb 2021 15:09:51 +0000 (17:09 +0200)
13 files changed:
source/glsl/builtin.cpp
source/glsl/compatibility.cpp
source/glsl/compatibility.h
source/glsl/compiler.cpp
source/glsl/compiler.h
source/glsl/generate.cpp
source/glsl/generate.h
source/glsl/optimize.cpp
source/glsl/optimize.h
source/glsl/output.cpp
source/glsl/output.h
source/glsl/visitor.cpp
source/glsl/visitor.h

index ee692462299dc7397a239a994a1b57228d385c4b..c9118f69d3ffa6e596d44111f8fc9e30b0835c42 100644 (file)
@@ -34,8 +34,7 @@ Module *create_builtins_module()
        Module *module = new Module(parser.parse(builtins_src, "<builtin>"));
        for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
        {
-               VariableResolver resolver;
-               i->content.visit(resolver);
+               VariableResolver().visit(i->content);
                for(map<string, VariableDeclaration *>::iterator j=i->content.variables.begin(); j!=i->content.variables.end(); ++j)
                        j->second->linked_declaration = j->second;
        }
index a3b2882dc538b0db82d603c12f792c3846ca36b5..398a7bd30a76eaa4b141b400d9cd4e08d0e87958 100644 (file)
@@ -14,9 +14,16 @@ namespace GL {
 namespace SL {
 
 DefaultPrecisionGenerator::DefaultPrecisionGenerator():
+       stage_type(Stage::SHARED),
        toplevel(true)
 { }
 
+void DefaultPrecisionGenerator::apply(Stage &stage)
+{
+       SetForScope<Stage::Type> set_stage(stage_type, stage.type);
+       visit(stage.content);
+}
+
 void DefaultPrecisionGenerator::visit(Block &block)
 {
        if(toplevel)
@@ -25,7 +32,7 @@ void DefaultPrecisionGenerator::visit(Block &block)
                BlockModifier::visit(block);
        }
        else
-               StageVisitor::visit(block);
+               TraversingVisitor::visit(block);
 }
 
 void DefaultPrecisionGenerator::visit(Precision &prec)
@@ -49,7 +56,7 @@ void DefaultPrecisionGenerator::visit(VariableDeclaration &var)
                Precision *prec = new Precision;
                if(!type.compare(0, 7, "sampler"))
                        prec->precision = "lowp";
-               else if(stage->type==Stage::FRAGMENT)
+               else if(stage_type==Stage::FRAGMENT)
                        prec->precision = "mediump";
                else
                        prec->precision = "highp";
@@ -106,6 +113,12 @@ bool LegacyConverter::check_extension(const Extension &extension) const
        return true;
 }
 
+void LegacyConverter::apply(Stage &s)
+{
+       SetForScope<Stage *> set_stage(stage, &s);
+       visit(s.content);
+}
+
 bool LegacyConverter::supports_unified_interface_syntax() const
 {
        if(target_api==OPENGL_ES2)
index f33768767b0b0909614e9c8181cf26a32cf66487..732f3477db5f9b401550ff999bde3d6f50af0718 100644 (file)
@@ -11,13 +11,16 @@ namespace SL {
 class DefaultPrecisionGenerator: public BlockModifier
 {
 private:
+       Stage::Type stage_type;
        bool toplevel;
        std::set<std::string> have_default;
 
 public:
        DefaultPrecisionGenerator();
 
-       using StageVisitor::visit;
+       void apply(Stage &);
+
+       using BlockModifier::visit;
        virtual void visit(Block &);
        virtual void visit(Precision &);
        virtual void visit(VariableDeclaration &);
@@ -26,7 +29,9 @@ public:
 class PrecisionRemover: public BlockModifier
 {
 public:
-       using StageVisitor::visit;
+       void apply(Stage &s) { visit(s.content); }
+
+       using BlockModifier::visit;
        virtual void visit(Precision &);
        virtual void visit(VariableDeclaration &);
 };
@@ -34,6 +39,7 @@ public:
 class LegacyConverter: public BlockModifier
 {
 private:
+       Stage *stage;
        GLApi target_api;
        Version target_version;
        std::string type;
@@ -46,7 +52,10 @@ public:
 private:
        bool check_version(const Version &) const;
        bool check_extension(const Extension &) const;
-       using StageVisitor::visit;
+public:
+       using BlockModifier::visit;
+       virtual void apply(Stage &);
+private:
        bool supports_unified_interface_syntax() const;
        virtual void visit(VariableReference &);
        virtual void visit(Assignment &);
index e6c035775a89b0ff1aaa98464755d031a73fed98..21060cb4b783975f72da2a039de2f6a07b86888f 100644 (file)
@@ -65,17 +65,19 @@ void Compiler::add_shaders(Program &program)
        {
                for(list<Stage>::iterator i=module->stages.begin(); i!=module->stages.end(); ++i)
                {
+                       string stage_src = Formatter().apply(*i);
+
                        if(i->type==Stage::VERTEX)
                        {
-                               program.attach_shader_owned(new VertexShader(apply<Formatter>(*i)));
+                               program.attach_shader_owned(new VertexShader(stage_src));
                                for(map<string, unsigned>::iterator j=i->locations.begin(); j!=i->locations.end(); ++j)
                                        program.bind_attribute(j->second, j->first);
                        }
                        else if(i->type==Stage::GEOMETRY)
-                               program.attach_shader_owned(new GeometryShader(apply<Formatter>(*i)));
+                               program.attach_shader_owned(new GeometryShader(stage_src));
                        else if(i->type==Stage::FRAGMENT)
                        {
-                               program.attach_shader_owned(new FragmentShader(apply<Formatter>(*i)));
+                               program.attach_shader_owned(new FragmentShader(stage_src));
                                if(EXT_gpu_shader4)
                                {
                                        for(map<string, unsigned>::iterator j=i->locations.begin(); j!=i->locations.end(); ++j)
@@ -124,10 +126,10 @@ void Compiler::add_shaders(Program &program)
 
 void Compiler::append_module(Module &mod)
 {
-       vector<Import *> imports = apply<NodeGatherer<Import> >(mod.shared);
+       vector<Import *> imports = NodeGatherer<Import>().apply(mod.shared);
        for(vector<Import *>::iterator i=imports.begin(); i!=imports.end(); ++i)
                import((*i)->module);
-       apply<NodeRemover>(mod.shared, set<Node *>(imports.begin(), imports.end()));
+       NodeRemover(set<Node *>(imports.begin(), imports.end())).apply(mod.shared);
 
        append_stage(mod.shared);
        for(list<Stage>::iterator i=mod.stages.begin(); i!=mod.stages.end(); ++i)
@@ -160,7 +162,7 @@ void Compiler::append_stage(Stage &stage)
                target->required_version = stage.required_version;
        for(NodeList<Statement>::iterator i=stage.content.body.begin(); i!=stage.content.body.end(); ++i)
                target->content.body.push_back(*i);
-       apply<DeclarationCombiner>(*target);
+       DeclarationCombiner().apply(*target);
 }
 
 void Compiler::process()
@@ -198,27 +200,27 @@ void Compiler::generate(Stage &stage)
                stage.required_version = module->shared.required_version;
        inject_block(stage.content, module->shared.content);
 
-       apply<DeclarationReorderer>(stage);
-       apply<FunctionResolver>(stage);
-       apply<VariableResolver>(stage);
-       apply<InterfaceGenerator>(stage);
-       apply<VariableResolver>(stage);
-       apply<DeclarationReorderer>(stage);
-       apply<FunctionResolver>(stage);
-       apply<LegacyConverter>(stage);
+       DeclarationReorderer().apply(stage);
+       FunctionResolver().apply(stage);
+       VariableResolver().apply(stage);
+       InterfaceGenerator().apply(stage);
+       VariableResolver().apply(stage);
+       DeclarationReorderer().apply(stage);
+       FunctionResolver().apply(stage);
+       LegacyConverter().apply(stage);
 }
 
 bool Compiler::optimize(Stage &stage)
 {
-       apply<ConstantConditionEliminator>(stage);
+       ConstantConditionEliminator().apply(stage);
 
-       set<FunctionDeclaration *> inlineable = apply<InlineableFunctionLocator>(stage);
-       apply<FunctionInliner>(stage, inlineable);
+       set<FunctionDeclaration *> inlineable = InlineableFunctionLocator().apply(stage);
+       FunctionInliner(inlineable).apply(stage);
 
-       set<Node *> unused = apply<UnusedVariableLocator>(stage);
-       set<Node *> unused2 = apply<UnusedFunctionLocator>(stage);
+       set<Node *> unused = UnusedVariableLocator().apply(stage);
+       set<Node *> unused2 = UnusedFunctionLocator().apply(stage);
        unused.insert(unused2.begin(), unused2.end());
-       apply<NodeRemover>(stage, unused);
+       NodeRemover(unused).apply(stage);
 
        return !unused.empty();
 }
@@ -226,9 +228,9 @@ bool Compiler::optimize(Stage &stage)
 void Compiler::finalize(Stage &stage)
 {
        if(get_gl_api()==OPENGL_ES2)
-               apply<DefaultPrecisionGenerator>(stage);
+               DefaultPrecisionGenerator().apply(stage);
        else
-               apply<PrecisionRemover>(stage);
+               PrecisionRemover().apply(stage);
 }
 
 void Compiler::inject_block(Block &target, const Block &source)
@@ -238,22 +240,6 @@ void Compiler::inject_block(Block &target, const Block &source)
                target.body.insert(insert_point, (*i)->clone());
 }
 
-template<typename T>
-typename T::ResultType Compiler::apply(Stage &stage)
-{
-       T visitor;
-       visitor.apply(stage);
-       return visitor.get_result();
-}
-
-template<typename T, typename A>
-typename T::ResultType Compiler::apply(Stage &stage, const A &arg)
-{
-       T visitor(arg);
-       visitor.apply(stage);
-       return visitor.get_result();
-}
-
 } // namespace SL
 } // namespace GL
 } // namespace Msp
index c8cb565fa1270c12d94c290a8479b0adef6f2084..2afc2f560bb00c17ecbdbfbc10f501ef11ad2813 100644 (file)
@@ -36,10 +36,6 @@ private:
        bool optimize(Stage &);
        void finalize(Stage &);
        static void inject_block(Block &, const Block &);
-       template<typename T>
-       static typename T::ResultType apply(Stage &);
-       template<typename T, typename A>
-       static typename T::ResultType apply(Stage &, const A &);
 };
 
 } // namespace SL
index 69a8dfb16d1eb62cccfbbba342418c0868bff69a..f3c57a2f0f9190352860bc5f7bc22ff634433123 100644 (file)
@@ -78,13 +78,12 @@ VariableResolver::VariableResolver():
        self_referencing(false)
 { }
 
-void VariableResolver::apply(Stage &s)
+void VariableResolver::apply(Stage &stage)
 {
-       SetForScope<Stage *> set(stage, &s);
-       Stage *builtins = get_builtins(stage->type);
+       Stage *builtins = get_builtins(stage.type);
        if(builtins)
                blocks.push_back(&builtins->content);
-       stage->content.visit(*this);
+       visit(stage.content);
        if(builtins)
                blocks.pop_back();
 }
@@ -240,6 +239,7 @@ void FunctionResolver::visit(FunctionDeclaration &func)
 
 
 InterfaceGenerator::InterfaceGenerator():
+       stage(0),
        scope_level(0)
 { }
 
@@ -259,7 +259,7 @@ void InterfaceGenerator::apply(Stage &s)
        if(stage->previous)
                in_prefix = get_out_prefix(stage->previous->type);
        out_prefix = get_out_prefix(stage->type);
-       stage->content.visit(*this);
+       visit(s.content);
 }
 
 void InterfaceGenerator::visit(Block &block)
@@ -471,7 +471,7 @@ void DeclarationReorderer::visit(Block &block)
 {
        SetForScope<unsigned> set(scope_level, scope_level+1);
        if(scope_level>1)
-               return StageVisitor::visit(block);
+               return TraversingVisitor::visit(block);
 
        NodeList<Statement>::iterator struct_insert_point = block.body.end();
        NodeList<Statement>::iterator variable_insert_point = block.body.end();
@@ -547,7 +547,7 @@ void DeclarationReorderer::visit(Block &block)
 
 void DeclarationReorderer::visit(VariableDeclaration &var)
 {
-       StageVisitor::visit(var);
+       TraversingVisitor::visit(var);
        kind = VARIABLE;
 }
 
index a09c552f1fd9f05c5329edabae112ece237017b0..90cf272219b9f2ba6bde8b5c4e6c349f5676218a 100644 (file)
@@ -21,13 +21,15 @@ private:
 public:
        DeclarationCombiner();
 
-       using StageVisitor::visit;
+       void apply(Stage &s) { visit(s.content); }
+
+       using BlockModifier::visit;
        virtual void visit(Block &);
        virtual void visit(FunctionDeclaration &);
        virtual void visit(VariableDeclaration &);
 };
 
-class VariableResolver: public StageVisitor
+class VariableResolver: public TraversingVisitor
 {
 private:
        std::vector<Block *> blocks;
@@ -41,8 +43,9 @@ private:
 public:
        VariableResolver();
 
-       virtual void apply(Stage &);
-       using StageVisitor::visit;
+       void apply(Stage &);
+
+       using TraversingVisitor::visit;
        virtual void visit(Block &);
        virtual void visit(VariableReference &);
        virtual void visit(MemberAccess &);
@@ -53,13 +56,15 @@ public:
        virtual void visit(InterfaceBlock &);
 };
 
-class FunctionResolver: public StageVisitor
+class FunctionResolver: public TraversingVisitor
 {
 private:
        std::map<std::string, std::vector<FunctionDeclaration *> > functions;
 
 public:
-       using StageVisitor::visit;
+       void apply(Stage &s) { visit(s.content); }
+
+       using TraversingVisitor::visit;
        virtual void visit(FunctionCall &);
        virtual void visit(FunctionDeclaration &);
 };
@@ -67,6 +72,7 @@ public:
 class InterfaceGenerator: public BlockModifier
 {
 private:
+       Stage *stage;
        std::string in_prefix;
        std::string out_prefix;
        unsigned scope_level;
@@ -75,9 +81,10 @@ private:
 public:
        InterfaceGenerator();
 
+       void apply(Stage &);
+
        static std::string get_out_prefix(Stage::Type);
-       virtual void apply(Stage &);
-       using StageVisitor::visit;
+       using BlockModifier::visit;
        virtual void visit(Block &);
        std::string change_prefix(const std::string &, const std::string &) const;
        bool generate_interface(VariableDeclaration &, const std::string &, const std::string &);
@@ -87,7 +94,7 @@ public:
        virtual void visit(Passthrough &);
 };
 
-class DeclarationReorderer: public StageVisitor
+class DeclarationReorderer: public TraversingVisitor
 {
 private:
        enum DeclarationKind
@@ -107,7 +114,9 @@ private:
 public:
        DeclarationReorderer();
 
-       using StageVisitor::visit;
+       void apply(Stage &s) { visit(s.content); }
+
+       using TraversingVisitor::visit;
        virtual void visit(Block &);
        virtual void visit(FunctionCall &);
        virtual void visit(InterfaceLayout &) { kind = LAYOUT; }
index f4d1680bccfbe76d1dd3ddb4829390fdee123f95..c76cae1b50f34a69ade9d31ead0c563250d24ae6 100644 (file)
@@ -217,10 +217,10 @@ UnusedVariableLocator::UnusedVariableLocator():
        global_scope(true)
 { }
 
-void UnusedVariableLocator::apply(Stage &s)
+const set<Node *> &UnusedVariableLocator::apply(Stage &s)
 {
        variables.push_back(BlockVariableMap());
-       StageVisitor::apply(s);
+       visit(s.content);
        BlockVariableMap &global_variables = variables.back();
        for(BlockVariableMap::iterator i=global_variables.begin(); i!=global_variables.end(); ++i)
        {
@@ -233,6 +233,8 @@ void UnusedVariableLocator::apply(Stage &s)
                }
        }
        variables.pop_back();
+
+       return unused_nodes;
 }
 
 void UnusedVariableLocator::visit(VariableReference &var)
index 698aa0b2617efd3ef1f6a98348aa754b19f04862..2e28d48a0dcbfe711057572cc7f5e1574dbe8789 100644 (file)
@@ -10,11 +10,8 @@ namespace Msp {
 namespace GL {
 namespace SL {
 
-class InlineableFunctionLocator: public StageVisitor
+class InlineableFunctionLocator: public TraversingVisitor
 {
-public:
-       typedef std::set<FunctionDeclaration *> ResultType;
-
 private:
        std::map<FunctionDeclaration *, unsigned> refcounts;
        std::set<FunctionDeclaration *> inlineable;
@@ -23,13 +20,14 @@ private:
 public:
        InlineableFunctionLocator();
 
-       const ResultType &get_result() const { return inlineable; }
-       using StageVisitor::visit;
+       const std::set<FunctionDeclaration *> &apply(Stage &s) { visit(s.content); return inlineable; }
+
+       using TraversingVisitor::visit;
        virtual void visit(FunctionCall &);
        virtual void visit(FunctionDeclaration &);
 };
 
-class FunctionInliner: public StageVisitor
+class FunctionInliner: public TraversingVisitor
 {
 private:
        std::set<FunctionDeclaration *> inlineable;
@@ -40,10 +38,12 @@ public:
        FunctionInliner();
        FunctionInliner(const std::set<FunctionDeclaration *> &);
 
+       void apply(Stage &s) { visit(s.content); }
+
 private:
        void visit_and_inline(RefPtr<Expression> &);
 public:
-       using StageVisitor::visit;
+       using TraversingVisitor::visit;
        virtual void visit(Block &);
        virtual void visit(UnaryExpression &);
        virtual void visit(BinaryExpression &);
@@ -63,7 +63,9 @@ private:
 public:
        ConstantConditionEliminator();
 
-       using StageVisitor::visit;
+       void apply(Stage &s) { visit(s.content); }
+
+       using BlockModifier::visit;
        virtual void visit(Block &);
        virtual void visit(UnaryExpression &);
        virtual void visit(Assignment &);
@@ -72,11 +74,8 @@ public:
        virtual void visit(Iteration &);
 };
 
-class UnusedVariableLocator: public StageVisitor
+class UnusedVariableLocator: public TraversingVisitor
 {
-public:
-       typedef std::set<Node *> ResultType;
-
 private:
        struct VariableInfo
        {
@@ -102,38 +101,39 @@ private:
 public:
        UnusedVariableLocator();
 
-       virtual void apply(Stage &);
-       const ResultType &get_result() const { return unused_nodes; }
-private:
-       using StageVisitor::visit;
+       const std::set<Node *> &apply(Stage &);
+
+       using TraversingVisitor::visit;
        virtual void visit(VariableReference &);
        virtual void visit(MemberAccess &);
        virtual void visit(BinaryExpression &);
        virtual void visit(Assignment &);
+private:
        void record_assignment(VariableDeclaration &, Node &, bool);
        void clear_assignments(VariableInfo &, bool);
+public:
        virtual void visit(ExpressionStatement &);
        virtual void visit(StructDeclaration &);
        virtual void visit(VariableDeclaration &);
        virtual void visit(InterfaceBlock &);
        virtual void visit(FunctionDeclaration &);
+private:
        void merge_down_variables();
+public:
        virtual void visit(Conditional &);
        virtual void visit(Iteration &);
 };
 
-class UnusedFunctionLocator: public StageVisitor
+class UnusedFunctionLocator: public TraversingVisitor
 {
-public:
-       typedef std::set<Node *> ResultType;
-
 private:
        std::set<Node *> unused_nodes;
        std::set<FunctionDeclaration *> used_definitions;
 
 public:
-       const ResultType &get_result() const { return unused_nodes; }
-       using StageVisitor::visit;
+       const std::set<Node *> &apply(Stage &s) { visit(s.content); return unused_nodes; }
+
+       using TraversingVisitor::visit;
        virtual void visit(FunctionCall &);
        virtual void visit(FunctionDeclaration &);
 };
index dfd93470dd101dfac8ca4c7f8a07be588bb11e9f..6e88fd1f870e6ba554a0fbfa0ddaba93d6347720 100644 (file)
@@ -9,14 +9,17 @@ namespace GL {
 namespace SL {
 
 Formatter::Formatter():
+       stage(0),
        source_index(0),
        source_line(1),
        indent(0),
        parameter_list(false)
 { }
 
-void Formatter::apply(Stage &s)
+const string &Formatter::apply(Stage &s)
 {
+       SetForScope<Stage *> set_stage(stage, &s);
+
        GLApi api = get_gl_api();
        const Version &ver = s.required_version;
 
@@ -33,7 +36,9 @@ void Formatter::apply(Stage &s)
        if(!s.required_extensions.empty())
                formatted += '\n';
 
-       StageVisitor::apply(s);
+       visit(s.content);
+
+       return formatted;
 }
 
 void Formatter::append(const string &text)
index b58c3f96f2ed7560c5ce85db1d5ffb6380aad87b..06bb5fd18f7069001b0275e29cc2d7986d40b0a5 100644 (file)
@@ -8,12 +8,10 @@ namespace Msp {
 namespace GL {
 namespace SL {
 
-class Formatter: public StageVisitor
+class Formatter: public TraversingVisitor
 {
-public:
-       typedef std::string ResultType;
-
 private:
+       Stage *stage;
        std::string formatted;
        unsigned source_index;
        unsigned source_line;
@@ -24,14 +22,14 @@ private:
 public:
        Formatter();
 
-       virtual void apply(Stage &);
-       const std::string &get_result() const { return formatted; }
+       const std::string &apply(Stage &);
+
 private:
        void append(const std::string &);
        void append(char);
        void set_source(unsigned, unsigned);
 public:
-       using StageVisitor::visit;
+       using TraversingVisitor::visit;
        virtual void visit(Block &);
        virtual void visit(Literal &);
        virtual void visit(ParenthesizedExpression &);
index 58d323ab540fc3be3a02c70a7c0a81dd333977b8..717acad1fdf6dbbd0f03e4e4a9a64ce3f2a4bcd2 100644 (file)
@@ -114,17 +114,6 @@ void TraversingVisitor::visit(Return &ret)
 }
 
 
-StageVisitor::StageVisitor():
-       stage(0)
-{ }
-
-void StageVisitor::apply(Stage &s)
-{
-       SetForScope<Stage *> set(stage, &s);
-       stage->content.visit(*this);
-}
-
-
 BlockModifier::BlockModifier():
        remove_node(false)
 { }
@@ -158,9 +147,16 @@ void BlockModifier::visit(Block &block)
 
 
 NodeRemover::NodeRemover(const set<Node *> &r):
+       stage(0),
        to_remove(r)
 { }
 
+void NodeRemover::apply(Stage &s)
+{
+       SetForScope<Stage *> set_stage(stage, &s);
+       visit(s.content);
+}
+
 void NodeRemover::visit(Block &block)
 {
        for(NodeList<Statement>::iterator i=block.body.begin(); i!=block.body.end(); )
index 582628bbedc49ecfe695c9eacd7909d861bfa071..e1c45b5c89de152976cfd94549c018f8e7607116 100644 (file)
@@ -66,22 +66,7 @@ public:
        virtual void visit(Return &);
 };
 
-class StageVisitor: public TraversingVisitor
-{
-public:
-       typedef void ResultType;
-
-protected:
-       Stage *stage;
-
-       StageVisitor();
-
-public:
-       virtual void apply(Stage &);
-       void get_result() const { }
-};
-
-class BlockModifier: public StageVisitor
+class BlockModifier: public TraversingVisitor
 {
 protected:
        bool remove_node;
@@ -93,35 +78,35 @@ protected:
        void apply_and_increment(Block &, NodeList<Statement>::iterator &);
 
 public:
-       using StageVisitor::visit;
+       using TraversingVisitor::visit;
        virtual void visit(Block &);
 };
 
 template<typename T>
-class NodeGatherer: public StageVisitor
+class NodeGatherer: public TraversingVisitor
 {
-public:
-       typedef std::vector<T *> ResultType;
-
 private:
        std::vector<T *> nodes;
 
 public:
-       const ResultType &get_result() const { return nodes; }
-       using StageVisitor::visit;
+       const std::vector<T *> &apply(Stage &s) { visit(s.content); return nodes; }
+
+       using TraversingVisitor::visit;
        virtual void visit(T &n) { nodes.push_back(&n); }
 };
 
-class NodeRemover: public StageVisitor
+class NodeRemover: public TraversingVisitor
 {
 private:
+       Stage *stage;
        std::set<Node *> to_remove;
 
 public:
-       NodeRemover() { }
        NodeRemover(const std::set<Node *> &);
 
-       using StageVisitor::visit;
+       void apply(Stage &);
+
+       using TraversingVisitor::visit;
        virtual void visit(Block &);
        virtual void visit(VariableDeclaration &);
        virtual void visit(Iteration &);