Shader: Rework scope handling

This commit is contained in:
Jérôme Leclercq
2021-04-04 20:31:09 +02:00
parent feffcfa6e5
commit f93a5bbdc1
23 changed files with 661 additions and 755 deletions

View File

@@ -8,7 +8,6 @@
#include <Nazara/Math/Algorithm.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <optional>
@@ -22,31 +21,87 @@ namespace Nz
static const char* flipYUniformName = "_NzFlipValue";
static const char* overridenMain = "_NzMain";
struct AstAdapter : ShaderAst::AstCloner
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
struct PreVisitor : ShaderAst::AstCloner
{
using AstCloner::Clone;
std::unique_ptr<ShaderAst::DeclareFunctionStatement> Clone(ShaderAst::DeclareFunctionStatement& node) override
ShaderAst::StatementPtr Clone(ShaderAst::DeclareFunctionStatement& node) override
{
auto clone = AstCloner::Clone(node);
if (clone->name == "main")
clone->name = "_NzMain";
assert(clone->GetType() == ShaderAst::NodeType::DeclareFunctionStatement);
ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get());
bool hasEntryPoint = false;
for (auto& attribute : func->attributes)
{
if (attribute.type == ShaderAst::AttributeType::Entry)
{
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
if (it->second == selectedEntryPoint)
{
hasEntryPoint = true;
break;
}
}
}
if (!hasEntryPoint)
return ShaderBuilder::NoOp();
entryPoint = func;
if (func->name == "main")
func->name = "_NzMain";
return clone;
}
void Visit(ShaderAst::DeclareFunctionStatement& node)
{
if (removedEntryPoints.find(&node) != removedEntryPoints.end())
{
PushStatement(ShaderBuilder::NoOp());
return;
}
ShaderStageType selectedEntryPoint;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
};
AstCloner::Visit(node);
struct EntryFuncResolver : ShaderAst::AstScopedVisitor
{
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
if (&node != entryPoint)
return;
assert(node.parameters.size() == 1);
const ShaderAst::ExpressionType& inputType = node.parameters.front().type;
const ShaderAst::ExpressionType& outputType = node.returnType;
const Identifier* identifier;
assert(IsIdentifierType(node.parameters.front().type));
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(inputType).name);
assert(identifier);
inputIdentifier = *identifier;
assert(IsIdentifierType(outputType));
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(outputType).name);
assert(identifier);
outputIdentifier = *identifier;
}
std::unordered_set<ShaderAst::DeclareFunctionStatement*> removedEntryPoints;
Identifier inputIdentifier;
Identifier outputIdentifier;
ShaderAst::DeclareFunctionStatement* entryPoint;
};
struct Builtin
@@ -64,7 +119,6 @@ namespace Nz
struct GlslWriter::State
{
const States* states = nullptr;
ShaderAst::AstCache cache;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream;
unsigned int indentLevel = 0;
@@ -86,29 +140,18 @@ namespace Nz
});
std::string error;
if (!ShaderAst::ValidateAst(shader, &error, &state.cache))
if (!ShaderAst::ValidateAst(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
if (!state.entryFunc)
PreVisitor previsitor;
previsitor.selectedEntryPoint = shaderStage;
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(shader);
if (!previsitor.entryPoint)
throw std::runtime_error("missing entry point");
AstAdapter adapter;
for (ShaderAst::DeclareFunctionStatement* entryFunc : state.cache.entryFunctions)
{
if (entryFunc != state.entryFunc)
adapter.removedEntryPoints.insert(entryFunc);
}
ShaderAst::StatementPtr adaptedShader = adapter.Clone(shader);
state.cache.Clear();
if (!ShaderAst::ValidateAst(adaptedShader, &error, &state.cache))
throw std::runtime_error("Internal error:" + error);
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
assert(state.entryFunc);
state.entryFunc = previsitor.entryPoint;
unsigned int glslVersion;
if (m_environment.glES)
@@ -190,10 +233,14 @@ namespace Nz
AppendLine();
}
adaptedShader->Visit(*this);
PushScope();
{
adaptedShader->Visit(*this);
// Append true GLSL entry point
AppendEntryPoint(shaderStage);
// Append true GLSL entry point
AppendEntryPoint(shaderStage, adaptedShader);
}
PopScope();
return state.stream.str();
}
@@ -340,8 +387,12 @@ namespace Nz
AppendLine();
}
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage)
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader)
{
EntryFuncResolver entryResolver;
entryResolver.entryPoint = m_currentState->entryFunc;
entryResolver.ScopedVisit(shader);
AppendLine();
AppendLine("// Entry point handling");
@@ -354,15 +405,10 @@ namespace Nz
std::vector<InOutField> inputFields;
const ShaderAst::StructDescription* inputStruct = nullptr;
auto HandleInOutStructs = [this, shaderStage](const ShaderAst::ExpressionType& expressionType, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
auto HandleInOutStructs = [this, shaderStage](const Identifier& identifier, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
{
assert(IsIdentifierType(expressionType));
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::IdentifierType>(expressionType).name);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier.value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier.value);
for (const auto& member : s.members)
{
@@ -426,17 +472,12 @@ namespace Nz
};
if (!m_currentState->entryFunc->parameters.empty())
{
assert(m_currentState->entryFunc->parameters.size() == 1);
const auto& parameter = m_currentState->entryFunc->parameters.front();
inputStruct = HandleInOutStructs(parameter.type, inputFields, "in", "_nzInput.", "_NzIn_");
}
inputStruct = HandleInOutStructs(entryResolver.inputIdentifier, inputFields, "in", "_nzInput.", "_NzIn_");
std::vector<InOutField> outputFields;
const ShaderAst::StructDescription* outputStruct = nullptr;
if (!IsNoType(m_currentState->entryFunc->returnType))
outputStruct = HandleInOutStructs(m_currentState->entryFunc->returnType, outputFields, "out", "_nzOutput.", "_NzOut_");
outputStruct = HandleInOutStructs(entryResolver.outputIdentifier, outputFields, "out", "_nzOutput.", "_NzOut_");
if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition)
AppendLine("uniform float ", flipYUniformName, ";");
@@ -486,12 +527,12 @@ namespace Nz
LeaveScope();
}
void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
void GlslWriter::AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
Append(".");
Append(memberIdentifier[0]);
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(scopeId, structName);
const Identifier* identifier = FindIdentifier(structName);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
@@ -503,7 +544,7 @@ namespace Nz
const auto& member = *memberIt;
if (remainingMembers > 1)
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
AppendField(std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
}
void GlslWriter::AppendLine(const std::string& txt)
@@ -558,12 +599,10 @@ namespace Nz
{
Visit(node.structExpr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
const ShaderAst::ExpressionType& exprType = node.structExpr->cachedExpressionType.value();
assert(IsIdentifierType(exprType));
std::size_t scopeId = m_currentState->cache.GetScopeId(&node);
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
AppendField(std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
@@ -593,7 +632,9 @@ namespace Nz
AppendLine(")");
EnterScope();
PushScope();
statement.statement->Visit(*this);
PopScope();
LeaveScope();
first = false;
@@ -604,7 +645,9 @@ namespace Nz
AppendLine("else");
EnterScope();
PushScope();
node.elseStatement->Visit(*this);
PopScope();
LeaveScope();
}
}
@@ -698,6 +741,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{
for (const auto& externalVar : node.externalVars)
{
std::optional<long long> bindingIndex;
@@ -729,7 +774,7 @@ namespace Nz
EnterScope();
{
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
const Identifier* identifier = FindIdentifier(std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
@@ -780,15 +825,19 @@ namespace Nz
Append(")\n");
EnterScope();
PushScope();
{
for (auto& statement : node.statements)
statement->Visit(*this);
}
PopScope();
LeaveScope();
}
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
RegisterStruct(node.description);
Append("struct ");
AppendLine(node.description.name);
EnterScope();
@@ -813,6 +862,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
{
RegisterVariable(node.varName, node.varType);
Append(node.varType);
Append(" ");
Append(node.varName);
@@ -871,6 +922,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::MultiStatement& node)
{
PushScope();
bool first = true;
for (const ShaderAst::StatementPtr& statement : node.statements)
{
@@ -881,6 +934,8 @@ namespace Nz
first = false;
}
PopScope();
}
void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/)