Shader: Rework scope handling
This commit is contained in:
@@ -8,7 +8,6 @@
|
||||
#include <Nazara/Math/Algorithm.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <Nazara/Shader/ShaderAstCloner.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Shader/ShaderAstUtils.hpp>
|
||||
#include <Nazara/Shader/ShaderAstValidator.hpp>
|
||||
#include <optional>
|
||||
@@ -22,31 +21,87 @@ namespace Nz
|
||||
static const char* flipYUniformName = "_NzFlipValue";
|
||||
static const char* overridenMain = "_NzMain";
|
||||
|
||||
struct AstAdapter : ShaderAst::AstCloner
|
||||
//FIXME: Have this only once
|
||||
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
|
||||
{ "frag", ShaderStageType::Fragment },
|
||||
{ "vert", ShaderStageType::Vertex },
|
||||
};
|
||||
|
||||
struct PreVisitor : ShaderAst::AstCloner
|
||||
{
|
||||
using AstCloner::Clone;
|
||||
|
||||
std::unique_ptr<ShaderAst::DeclareFunctionStatement> Clone(ShaderAst::DeclareFunctionStatement& node) override
|
||||
ShaderAst::StatementPtr Clone(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
auto clone = AstCloner::Clone(node);
|
||||
if (clone->name == "main")
|
||||
clone->name = "_NzMain";
|
||||
assert(clone->GetType() == ShaderAst::NodeType::DeclareFunctionStatement);
|
||||
|
||||
ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get());
|
||||
|
||||
bool hasEntryPoint = false;
|
||||
|
||||
for (auto& attribute : func->attributes)
|
||||
{
|
||||
if (attribute.type == ShaderAst::AttributeType::Entry)
|
||||
{
|
||||
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
|
||||
assert(it != s_entryPoints.end());
|
||||
|
||||
if (it->second == selectedEntryPoint)
|
||||
{
|
||||
hasEntryPoint = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasEntryPoint)
|
||||
return ShaderBuilder::NoOp();
|
||||
|
||||
entryPoint = func;
|
||||
|
||||
if (func->name == "main")
|
||||
func->name = "_NzMain";
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node)
|
||||
{
|
||||
if (removedEntryPoints.find(&node) != removedEntryPoints.end())
|
||||
{
|
||||
PushStatement(ShaderBuilder::NoOp());
|
||||
return;
|
||||
}
|
||||
ShaderStageType selectedEntryPoint;
|
||||
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
|
||||
};
|
||||
|
||||
AstCloner::Visit(node);
|
||||
struct EntryFuncResolver : ShaderAst::AstScopedVisitor
|
||||
{
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
|
||||
|
||||
if (&node != entryPoint)
|
||||
return;
|
||||
|
||||
assert(node.parameters.size() == 1);
|
||||
|
||||
const ShaderAst::ExpressionType& inputType = node.parameters.front().type;
|
||||
const ShaderAst::ExpressionType& outputType = node.returnType;
|
||||
|
||||
const Identifier* identifier;
|
||||
|
||||
assert(IsIdentifierType(node.parameters.front().type));
|
||||
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(inputType).name);
|
||||
assert(identifier);
|
||||
|
||||
inputIdentifier = *identifier;
|
||||
|
||||
assert(IsIdentifierType(outputType));
|
||||
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(outputType).name);
|
||||
assert(identifier);
|
||||
|
||||
outputIdentifier = *identifier;
|
||||
}
|
||||
|
||||
std::unordered_set<ShaderAst::DeclareFunctionStatement*> removedEntryPoints;
|
||||
Identifier inputIdentifier;
|
||||
Identifier outputIdentifier;
|
||||
ShaderAst::DeclareFunctionStatement* entryPoint;
|
||||
};
|
||||
|
||||
struct Builtin
|
||||
@@ -64,7 +119,6 @@ namespace Nz
|
||||
struct GlslWriter::State
|
||||
{
|
||||
const States* states = nullptr;
|
||||
ShaderAst::AstCache cache;
|
||||
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
|
||||
std::stringstream stream;
|
||||
unsigned int indentLevel = 0;
|
||||
@@ -86,29 +140,18 @@ namespace Nz
|
||||
});
|
||||
|
||||
std::string error;
|
||||
if (!ShaderAst::ValidateAst(shader, &error, &state.cache))
|
||||
if (!ShaderAst::ValidateAst(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
|
||||
if (!state.entryFunc)
|
||||
PreVisitor previsitor;
|
||||
previsitor.selectedEntryPoint = shaderStage;
|
||||
|
||||
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(shader);
|
||||
|
||||
if (!previsitor.entryPoint)
|
||||
throw std::runtime_error("missing entry point");
|
||||
|
||||
AstAdapter adapter;
|
||||
|
||||
for (ShaderAst::DeclareFunctionStatement* entryFunc : state.cache.entryFunctions)
|
||||
{
|
||||
if (entryFunc != state.entryFunc)
|
||||
adapter.removedEntryPoints.insert(entryFunc);
|
||||
}
|
||||
|
||||
ShaderAst::StatementPtr adaptedShader = adapter.Clone(shader);
|
||||
|
||||
state.cache.Clear();
|
||||
if (!ShaderAst::ValidateAst(adaptedShader, &error, &state.cache))
|
||||
throw std::runtime_error("Internal error:" + error);
|
||||
|
||||
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
|
||||
assert(state.entryFunc);
|
||||
state.entryFunc = previsitor.entryPoint;
|
||||
|
||||
unsigned int glslVersion;
|
||||
if (m_environment.glES)
|
||||
@@ -190,10 +233,14 @@ namespace Nz
|
||||
AppendLine();
|
||||
}
|
||||
|
||||
adaptedShader->Visit(*this);
|
||||
PushScope();
|
||||
{
|
||||
adaptedShader->Visit(*this);
|
||||
|
||||
// Append true GLSL entry point
|
||||
AppendEntryPoint(shaderStage);
|
||||
// Append true GLSL entry point
|
||||
AppendEntryPoint(shaderStage, adaptedShader);
|
||||
}
|
||||
PopScope();
|
||||
|
||||
return state.stream.str();
|
||||
}
|
||||
@@ -340,8 +387,12 @@ namespace Nz
|
||||
AppendLine();
|
||||
}
|
||||
|
||||
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage)
|
||||
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader)
|
||||
{
|
||||
EntryFuncResolver entryResolver;
|
||||
entryResolver.entryPoint = m_currentState->entryFunc;
|
||||
entryResolver.ScopedVisit(shader);
|
||||
|
||||
AppendLine();
|
||||
AppendLine("// Entry point handling");
|
||||
|
||||
@@ -354,15 +405,10 @@ namespace Nz
|
||||
std::vector<InOutField> inputFields;
|
||||
const ShaderAst::StructDescription* inputStruct = nullptr;
|
||||
|
||||
auto HandleInOutStructs = [this, shaderStage](const ShaderAst::ExpressionType& expressionType, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
|
||||
auto HandleInOutStructs = [this, shaderStage](const Identifier& identifier, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
|
||||
{
|
||||
assert(IsIdentifierType(expressionType));
|
||||
|
||||
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::IdentifierType>(expressionType).name);
|
||||
assert(identifier);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier.value));
|
||||
const auto& s = std::get<ShaderAst::StructDescription>(identifier.value);
|
||||
|
||||
for (const auto& member : s.members)
|
||||
{
|
||||
@@ -426,17 +472,12 @@ namespace Nz
|
||||
};
|
||||
|
||||
if (!m_currentState->entryFunc->parameters.empty())
|
||||
{
|
||||
assert(m_currentState->entryFunc->parameters.size() == 1);
|
||||
const auto& parameter = m_currentState->entryFunc->parameters.front();
|
||||
|
||||
inputStruct = HandleInOutStructs(parameter.type, inputFields, "in", "_nzInput.", "_NzIn_");
|
||||
}
|
||||
inputStruct = HandleInOutStructs(entryResolver.inputIdentifier, inputFields, "in", "_nzInput.", "_NzIn_");
|
||||
|
||||
std::vector<InOutField> outputFields;
|
||||
const ShaderAst::StructDescription* outputStruct = nullptr;
|
||||
if (!IsNoType(m_currentState->entryFunc->returnType))
|
||||
outputStruct = HandleInOutStructs(m_currentState->entryFunc->returnType, outputFields, "out", "_nzOutput.", "_NzOut_");
|
||||
outputStruct = HandleInOutStructs(entryResolver.outputIdentifier, outputFields, "out", "_nzOutput.", "_NzOut_");
|
||||
|
||||
if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition)
|
||||
AppendLine("uniform float ", flipYUniformName, ";");
|
||||
@@ -486,12 +527,12 @@ namespace Nz
|
||||
LeaveScope();
|
||||
}
|
||||
|
||||
void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
|
||||
void GlslWriter::AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
|
||||
{
|
||||
Append(".");
|
||||
Append(memberIdentifier[0]);
|
||||
|
||||
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(scopeId, structName);
|
||||
const Identifier* identifier = FindIdentifier(structName);
|
||||
assert(identifier);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
@@ -503,7 +544,7 @@ namespace Nz
|
||||
const auto& member = *memberIt;
|
||||
|
||||
if (remainingMembers > 1)
|
||||
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
|
||||
AppendField(std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
|
||||
}
|
||||
|
||||
void GlslWriter::AppendLine(const std::string& txt)
|
||||
@@ -558,12 +599,10 @@ namespace Nz
|
||||
{
|
||||
Visit(node.structExpr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
|
||||
const ShaderAst::ExpressionType& exprType = node.structExpr->cachedExpressionType.value();
|
||||
assert(IsIdentifierType(exprType));
|
||||
|
||||
std::size_t scopeId = m_currentState->cache.GetScopeId(&node);
|
||||
|
||||
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
|
||||
AppendField(std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
|
||||
@@ -593,7 +632,9 @@ namespace Nz
|
||||
AppendLine(")");
|
||||
|
||||
EnterScope();
|
||||
PushScope();
|
||||
statement.statement->Visit(*this);
|
||||
PopScope();
|
||||
LeaveScope();
|
||||
|
||||
first = false;
|
||||
@@ -604,7 +645,9 @@ namespace Nz
|
||||
AppendLine("else");
|
||||
|
||||
EnterScope();
|
||||
PushScope();
|
||||
node.elseStatement->Visit(*this);
|
||||
PopScope();
|
||||
LeaveScope();
|
||||
}
|
||||
}
|
||||
@@ -698,6 +741,8 @@ namespace Nz
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
{
|
||||
|
||||
|
||||
for (const auto& externalVar : node.externalVars)
|
||||
{
|
||||
std::optional<long long> bindingIndex;
|
||||
@@ -729,7 +774,7 @@ namespace Nz
|
||||
|
||||
EnterScope();
|
||||
{
|
||||
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
|
||||
const Identifier* identifier = FindIdentifier(std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
|
||||
assert(identifier);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
@@ -780,15 +825,19 @@ namespace Nz
|
||||
Append(")\n");
|
||||
|
||||
EnterScope();
|
||||
PushScope();
|
||||
{
|
||||
for (auto& statement : node.statements)
|
||||
statement->Visit(*this);
|
||||
}
|
||||
PopScope();
|
||||
LeaveScope();
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
|
||||
{
|
||||
RegisterStruct(node.description);
|
||||
|
||||
Append("struct ");
|
||||
AppendLine(node.description.name);
|
||||
EnterScope();
|
||||
@@ -813,6 +862,8 @@ namespace Nz
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
|
||||
{
|
||||
RegisterVariable(node.varName, node.varType);
|
||||
|
||||
Append(node.varType);
|
||||
Append(" ");
|
||||
Append(node.varName);
|
||||
@@ -871,6 +922,8 @@ namespace Nz
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::MultiStatement& node)
|
||||
{
|
||||
PushScope();
|
||||
|
||||
bool first = true;
|
||||
for (const ShaderAst::StatementPtr& statement : node.statements)
|
||||
{
|
||||
@@ -881,6 +934,8 @@ namespace Nz
|
||||
|
||||
first = false;
|
||||
}
|
||||
|
||||
PopScope();
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/)
|
||||
|
||||
Reference in New Issue
Block a user