From a48aaa402e2aacab780805f529cde4ded7ae3f59 Mon Sep 17 00:00:00 2001
From: Mikko Rasa <tdb@tdb.fi>
Date: Fri, 2 Dec 2016 12:32:23 +0200
Subject: [PATCH] Inline trivial shader functions that are only called once

I.e. functions that consist of a single return statement.
---
 source/programcompiler.cpp | 105 +++++++++++++++++++++++++++++++++++++
 source/programcompiler.h   |  32 +++++++++++
 2 files changed, 137 insertions(+)

diff --git a/source/programcompiler.cpp b/source/programcompiler.cpp
index 7224df1e..0e749c97 100644
--- a/source/programcompiler.cpp
+++ b/source/programcompiler.cpp
@@ -173,6 +173,9 @@ bool ProgramCompiler::optimize(Stage &stage)
 {
 	apply<ConstantConditionEliminator>(stage);
 
+	set<FunctionDeclaration *> inlineable = apply<InlineableFunctionLocator>(stage);
+	apply<FunctionInliner>(stage, inlineable);
+
 	set<Node *> unused = apply<UnusedVariableLocator>(stage);
 	set<Node *> unused2 = apply<UnusedFunctionLocator>(stage);
 	unused.insert(unused2.begin(), unused2.end());
@@ -957,6 +960,108 @@ void ProgramCompiler::DeclarationReorderer::visit(Block &block)
 }
 
 
+ProgramCompiler::InlineableFunctionLocator::InlineableFunctionLocator():
+	in_function(0)
+{ }
+
+void ProgramCompiler::InlineableFunctionLocator::visit(FunctionCall &call)
+{
+	FunctionDeclaration *def = call.declaration;
+	if(def && def->definition!=def)
+		def = def->definition;
+
+	if(def)
+	{
+		unsigned &count = refcounts[def];
+		++count;
+		if(count>1 || def==in_function)
+			inlineable.erase(def);
+	}
+
+	TraversingVisitor::visit(call);
+}
+
+void ProgramCompiler::InlineableFunctionLocator::visit(FunctionDeclaration &func)
+{
+	unsigned &count = refcounts[func.definition];
+	if(!count && func.parameters.empty())
+		inlineable.insert(func.definition);
+
+	SetForScope<FunctionDeclaration *> set(in_function, &func);
+	TraversingVisitor::visit(func);
+}
+
+
+ProgramCompiler::FunctionInliner::FunctionInliner():
+	extract_result(0)
+{ }
+
+ProgramCompiler::FunctionInliner::FunctionInliner(const set<FunctionDeclaration *> &in):
+	inlineable(in),
+	extract_result(0)
+{ }
+
+void ProgramCompiler::FunctionInliner::visit_and_inline(RefPtr<Expression> &ptr)
+{
+	inline_result = 0;
+	ptr->visit(*this);
+	if(inline_result)
+		ptr = inline_result;
+}
+
+void ProgramCompiler::FunctionInliner::visit(Block &block)
+{
+	if(extract_result)
+		--extract_result;
+
+	for(list<RefPtr<Node> >::iterator i=block.body.begin(); i!=block.body.end(); ++i)
+	{
+		(*i)->visit(*this);
+		if(extract_result)
+			--extract_result;
+	}
+}
+
+void ProgramCompiler::FunctionInliner::visit(UnaryExpression &unary)
+{
+	visit_and_inline(unary.expression);
+	inline_result = 0;
+}
+
+void ProgramCompiler::FunctionInliner::visit(BinaryExpression &binary)
+{
+	visit_and_inline(binary.left);
+	visit_and_inline(binary.right);
+	inline_result = 0;
+}
+
+void ProgramCompiler::FunctionInliner::visit(FunctionCall &call)
+{
+	for(vector<RefPtr<Expression> >::iterator i=call.arguments.begin(); i!=call.arguments.end(); ++i)
+		visit_and_inline(*i);
+
+	FunctionDeclaration *def = call.declaration;
+	if(def && def->definition!=def)
+		def = def->definition;
+
+	if(def && inlineable.count(def))
+	{
+		extract_result = 2;
+		def->visit(*this);
+	}
+	else
+		inline_result = 0;
+}
+
+void ProgramCompiler::FunctionInliner::visit(Return &ret)
+{
+	TraversingVisitor::visit(ret);
+
+	if(extract_result)
+		inline_result = ret.expression->clone();
+}
+
+
 ProgramCompiler::ExpressionEvaluator::ExpressionEvaluator():
 	variable_values(0),
 	result(0.0f),
diff --git a/source/programcompiler.h b/source/programcompiler.h
index 8c4034d4..46c37930 100644
--- a/source/programcompiler.h
+++ b/source/programcompiler.h
@@ -179,6 +179,38 @@ private:
 		virtual void visit(ProgramSyntax::FunctionDeclaration &) { kind = FUNCTION; }
 	};
 
+	struct InlineableFunctionLocator: Visitor
+	{
+		typedef std::set<ProgramSyntax::FunctionDeclaration *> ResultType;
+
+		std::map<ProgramSyntax::FunctionDeclaration *, unsigned> refcounts;
+		std::set<ProgramSyntax::FunctionDeclaration *> inlineable;
+		ProgramSyntax::FunctionDeclaration *in_function;
+
+		InlineableFunctionLocator();
+
+		const ResultType &get_result() const { return inlineable; }
+		virtual void visit(ProgramSyntax::FunctionCall &);
+		virtual void visit(ProgramSyntax::FunctionDeclaration &);
+	};
+
+	struct FunctionInliner: Visitor
+	{
+		std::set<ProgramSyntax::FunctionDeclaration *> inlineable;
+		unsigned extract_result;
+		RefPtr<ProgramSyntax::Expression> inline_result;
+
+		FunctionInliner();
+		FunctionInliner(const std::set<ProgramSyntax::FunctionDeclaration *> &);
+
+		void visit_and_inline(RefPtr<ProgramSyntax::Expression> &);
+		virtual void visit(ProgramSyntax::Block &);
+		virtual void visit(ProgramSyntax::UnaryExpression &);
+		virtual void visit(ProgramSyntax::BinaryExpression &);
+		virtual void visit(ProgramSyntax::FunctionCall &);
+		virtual void visit(ProgramSyntax::Return &);
+	};
+
 	struct ExpressionEvaluator: ProgramSyntax::NodeVisitor
 	{
 		typedef std::map<ProgramSyntax::VariableDeclaration *, ProgramSyntax::Expression *> ValueMap;
-- 
2.45.2