diff --git a/examples/RenderTest/main.cpp b/examples/RenderTest/main.cpp index d74fb15b5..23a2c8cf9 100644 --- a/examples/RenderTest/main.cpp +++ b/examples/RenderTest/main.cpp @@ -19,7 +19,7 @@ struct Data external { - [[binding(0), layout(std140)]] viewerData: uniform, + [[binding(0)]] viewerData: uniform, [[binding(1)]] tex: sampler2D } diff --git a/include/Nazara/Shader/Ast/ExpressionType.hpp b/include/Nazara/Shader/Ast/ExpressionType.hpp index a54b03395..559602038 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.hpp +++ b/include/Nazara/Shader/Ast/ExpressionType.hpp @@ -95,6 +95,7 @@ namespace Nz::ShaderAst inline bool IsNoType(const ExpressionType& type); inline bool IsPrimitiveType(const ExpressionType& type); inline bool IsSamplerType(const ExpressionType& type); + inline bool IsStructType(const ExpressionType& type); inline bool IsUniformType(const ExpressionType& type); inline bool IsVectorType(const ExpressionType& type); } diff --git a/include/Nazara/Shader/Ast/ExpressionType.inl b/include/Nazara/Shader/Ast/ExpressionType.inl index 2fdb36674..ec9580616 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.inl +++ b/include/Nazara/Shader/Ast/ExpressionType.inl @@ -109,6 +109,11 @@ namespace Nz::ShaderAst return std::holds_alternative(type); } + bool IsStructType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + bool IsUniformType(const ExpressionType& type) { return std::holds_alternative(type); diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 7b4741cd4..c8f270b1d 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -9,7 +9,8 @@ #include #include -#include +#include +#include #include #include #include @@ -17,7 +18,7 @@ namespace Nz { - class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstScopedVisitor + class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept { public: struct Environment; @@ -59,24 +60,27 @@ namespace Nz template void Append(const T1& firstParam, const T2& secondParam, Args&&... params); void AppendCommentSection(const std::string& section); void AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader); - void AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers); + void AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers); void AppendLine(const std::string& txt = {}); template void AppendLine(Args&&... params); void EnterScope(); void LeaveScope(bool skipLine = true); + void RegisterStruct(std::size_t structIndex, bool isStd140, ShaderAst::StructDescription desc); + void RegisterVariable(std::size_t varIndex, std::string varName); + void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); - void Visit(ShaderAst::AccessMemberIdentifierExpression& node) override; + void Visit(ShaderAst::AccessMemberIndexExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::ConditionalExpression& node) override; void Visit(ShaderAst::ConstantExpression& node) override; - void Visit(ShaderAst::IdentifierExpression& node) override; void Visit(ShaderAst::IntrinsicExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; + void Visit(ShaderAst::VariableExpression& node) override; void Visit(ShaderAst::BranchStatement& node) override; void Visit(ShaderAst::ConditionalStatement& node) override; diff --git a/src/Nazara/OpenGLRenderer/OpenGLShaderModule.cpp b/src/Nazara/OpenGLRenderer/OpenGLShaderModule.cpp index 8fe618231..86b043877 100644 --- a/src/Nazara/OpenGLRenderer/OpenGLShaderModule.cpp +++ b/src/Nazara/OpenGLRenderer/OpenGLShaderModule.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -122,7 +123,9 @@ namespace Nz if (!shader.Create(device, ToOpenGL(shaderStage))) throw std::runtime_error("failed to create shader"); //< TODO: Handle error message - std::string code = writer.Generate(shaderStage, shaderAst, states); + ShaderAst::AstCloner cloner; //< FIXME: Required because writer may update AST + ShaderAst::StatementPtr clonedAst = cloner.Clone(shaderAst); + std::string code = writer.Generate(shaderStage, clonedAst, states); shader.SetSource(code.data(), code.size()); shader.Compile(); diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 46c63e5f7..d3ddb8827 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -70,40 +71,6 @@ namespace Nz ShaderAst::DeclareFunctionStatement* entryPoint = nullptr; }; - struct EntryFuncResolver : ShaderAst::AstScopedVisitor - { - void Visit(ShaderAst::DeclareFunctionStatement& node) override - { - - - if (&node != entryPoint) - return; - - assert(node.parameters.size() == 1); - - const ShaderAst::ExpressionType& inputType = node.parameters.front().type; - const ShaderAst::ExpressionType& outputType = node.returnType; - - const Identifier* identifier; - - assert(IsIdentifierType(node.parameters.front().type)); - identifier = FindIdentifier(std::get(inputType).name); - assert(identifier); - - inputIdentifier = *identifier; - - assert(IsIdentifierType(outputType)); - identifier = FindIdentifier(std::get(outputType).name); - assert(identifier); - - outputIdentifier = *identifier; - } - - Identifier inputIdentifier; - Identifier outputIdentifier; - ShaderAst::DeclareFunctionStatement* entryPoint; - }; - struct Builtin { std::string identifier; @@ -118,9 +85,17 @@ namespace Nz struct GlslWriter::State { + struct StructInfo + { + ShaderAst::StructDescription structDesc; + bool isStd140 = false; + }; + const States* states = nullptr; ShaderAst::DeclareFunctionStatement* entryFunc = nullptr; std::stringstream stream; + std::unordered_map structs; + std::unordered_map variableNames; unsigned int indentLevel = 0; }; @@ -143,10 +118,13 @@ namespace Nz if (!ShaderAst::ValidateAst(shader, &error)) throw std::runtime_error("Invalid shader AST: " + error); + ShaderAst::TransformVisitor transformVisitor; + ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader); + PreVisitor previsitor; previsitor.selectedEntryPoint = shaderStage; - ShaderAst::StatementPtr adaptedShader = previsitor.Clone(shader); + ShaderAst::StatementPtr adaptedShader = previsitor.Clone(transformedShader); if (!previsitor.entryPoint) throw std::runtime_error("missing entry point"); @@ -233,14 +211,10 @@ namespace Nz AppendLine(); } - PushScope(); - { - adaptedShader->Visit(*this); + adaptedShader->Visit(*this); - // Append true GLSL entry point - AppendEntryPoint(shaderStage, adaptedShader); - } - PopScope(); + // Append true GLSL entry point + AppendEntryPoint(shaderStage, adaptedShader); return state.stream.str(); } @@ -275,7 +249,7 @@ namespace Nz void GlslWriter::Append(const ShaderAst::IdentifierType& identifierType) { - Append(identifierType.name); + throw std::runtime_error("unexpected identifier type"); } void GlslWriter::Append(const ShaderAst::MatrixType& matrixType) @@ -332,7 +306,8 @@ namespace Nz void GlslWriter::Append(const ShaderAst::StructType& structType) { - throw std::runtime_error("unexpected struct type"); + const auto& structDesc = m_currentState->structs[structType.structIndex].structDesc; + Append(structDesc.name); } void GlslWriter::Append(const ShaderAst::UniformType& uniformType) @@ -395,9 +370,24 @@ namespace Nz void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader) { - EntryFuncResolver entryResolver; - entryResolver.entryPoint = m_currentState->entryFunc; - entryResolver.ScopedVisit(shader); + ShaderAst::DeclareFunctionStatement& entryFunc = *m_currentState->entryFunc; + + std::optional inputStructIndex; + if (!entryFunc.parameters.empty()) + { + assert(entryFunc.parameters.size() == 1); + auto& parameter = entryFunc.parameters.front(); + assert(std::holds_alternative(parameter.type)); + + inputStructIndex = std::get(parameter.type).structIndex; + } + + std::optional outputStructIndex; + if (!IsNoType(entryFunc.returnType)) + { + assert(std::holds_alternative(entryFunc.returnType)); + outputStructIndex = std::get(entryFunc.returnType).structIndex; + } AppendLine(); AppendLine("// Entry point handling"); @@ -411,12 +401,11 @@ namespace Nz std::vector inputFields; const ShaderAst::StructDescription* inputStruct = nullptr; - auto HandleInOutStructs = [this, shaderStage](const Identifier& identifier, std::vector& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription* + auto HandleInOutStructs = [this, shaderStage](std::size_t structIndex, std::vector& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription* { - assert(std::holds_alternative(identifier.value)); - const auto& s = std::get(identifier.value); + const auto& structDesc = m_currentState->structs[structIndex].structDesc; - for (const auto& member : s.members) + for (const auto& member : structDesc.members) { bool skip = false; std::optional builtinName; @@ -474,16 +463,16 @@ namespace Nz } AppendLine(); - return &s; + return &structDesc; }; - if (!m_currentState->entryFunc->parameters.empty()) - inputStruct = HandleInOutStructs(entryResolver.inputIdentifier, inputFields, "in", "_nzInput.", "_NzIn_"); + if (inputStructIndex) + inputStruct = HandleInOutStructs(*inputStructIndex, inputFields, "in", "_nzInput.", "_NzIn_"); std::vector outputFields; const ShaderAst::StructDescription* outputStruct = nullptr; - if (!IsNoType(m_currentState->entryFunc->returnType)) - outputStruct = HandleInOutStructs(entryResolver.outputIdentifier, outputFields, "out", "_nzOutput.", "_NzOut_"); + if (outputStructIndex) + outputStruct = HandleInOutStructs(*outputStructIndex, outputFields, "out", "_nzOutput.", "_NzOut_"); if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition) AppendLine("uniform float ", flipYUniformName, ";"); @@ -533,24 +522,20 @@ namespace Nz LeaveScope(); } - void GlslWriter::AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) + void GlslWriter::AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers) { + const auto& structDesc = m_currentState->structs[structIndex].structDesc; + + const auto& member = structDesc.members[*memberIndices]; + Append("."); - Append(memberIdentifier[0]); - - const Identifier* identifier = FindIdentifier(structName); - assert(identifier); - - assert(std::holds_alternative(identifier->value)); - const auto& s = std::get(identifier->value); - - auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; }); - assert(memberIt != s.members.end()); - - const auto& member = *memberIt; + Append(member.name); if (remainingMembers > 1) - AppendField(std::get(member.type).name, memberIdentifier + 1, remainingMembers - 1); + { + assert(IsStructType(member.type)); + AppendField(std::get(member.type).structIndex, memberIndices + 1, remainingMembers - 1); + } } void GlslWriter::AppendLine(const std::string& txt) @@ -588,6 +573,21 @@ namespace Nz Append("}"); } + void GlslWriter::RegisterStruct(std::size_t structIndex, bool isStd140, ShaderAst::StructDescription desc) + { + assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); + m_currentState->structs.emplace(structIndex, State::StructInfo{ + std::move(desc), + isStd140 + }); + } + + void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName) + { + assert(m_currentState->variableNames.find(varIndex) == m_currentState->variableNames.end()); + m_currentState->variableNames.emplace(varIndex, std::move(varName)); + } + void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired) { bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue); @@ -601,14 +601,14 @@ namespace Nz Append(")"); } - void GlslWriter::Visit(ShaderAst::AccessMemberIdentifierExpression& node) + void GlslWriter::Visit(ShaderAst::AccessMemberIndexExpression& node) { Visit(node.structExpr, true); - const ShaderAst::ExpressionType& exprType = node.structExpr->cachedExpressionType.value(); - assert(IsIdentifierType(exprType)); + const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr); + assert(IsStructType(exprType)); - AppendField(std::get(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size()); + AppendField(std::get(exprType).structIndex, node.memberIndices.data(), node.memberIndices.size()); } void GlslWriter::Visit(ShaderAst::AssignExpression& node) @@ -638,9 +638,7 @@ namespace Nz AppendLine(")"); EnterScope(); - PushScope(); statement.statement->Visit(*this); - PopScope(); LeaveScope(); first = false; @@ -651,9 +649,7 @@ namespace Nz AppendLine("else"); EnterScope(); - PushScope(); node.elseStatement->Visit(*this); - PopScope(); LeaveScope(); } } @@ -747,21 +743,32 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node) { + assert(node.varIndex); + std::size_t varIndex = *node.varIndex; + for (const auto& externalVar : node.externalVars) { std::optional bindingIndex; - bool isStd140 = false; for (const auto& [attributeType, attributeParam] : externalVar.attributes) { if (attributeType == ShaderAst::AttributeType::Binding) - bindingIndex = std::get(attributeParam); - else if (attributeType == ShaderAst::AttributeType::Layout) { - if (std::get(attributeParam) == "std140") - isStd140 = true; + bindingIndex = std::get(attributeParam); + break; } } + bool isStd140 = false; + if (IsUniformType(externalVar.type)) + { + auto& uniform = std::get(externalVar.type); + assert(std::holds_alternative(uniform.containedType)); + + std::size_t structIndex = std::get(uniform.containedType).structIndex; + auto& structInfo = m_currentState->structs[structIndex]; + isStd140 = structInfo.isStd140; + } + if (bindingIndex) { Append("layout(binding = "); @@ -773,19 +780,20 @@ namespace Nz if (IsUniformType(externalVar.type)) { + Append("_NzBinding_"); AppendLine(externalVar.name); EnterScope(); { - const Identifier* identifier = FindIdentifier(std::get(std::get(externalVar.type).containedType).name); - assert(identifier); + auto& uniform = std::get(externalVar.type); + assert(std::holds_alternative(uniform.containedType)); - assert(std::holds_alternative(identifier->value)); - const auto& s = std::get(identifier->value); + std::size_t structIndex = std::get(uniform.containedType).structIndex; + auto& structInfo = m_currentState->structs[structIndex]; bool first = true; - for (const auto& [name, attribute, type] : s.members) + for (const auto& [name, attribute, type] : structInfo.structDesc.members) { if (!first) AppendLine(); @@ -807,6 +815,8 @@ namespace Nz Append(externalVar.name); AppendLine(";"); } + + RegisterVariable(varIndex++, externalVar.name); } } @@ -814,6 +824,9 @@ namespace Nz { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); + assert(node.varIndex); + std::size_t varIndex = *node.varIndex; + Append(node.returnType); Append(" "); Append(node.name); @@ -825,22 +838,33 @@ namespace Nz Append(node.parameters[i].type); Append(" "); Append(node.parameters[i].name); + + RegisterVariable(varIndex++, node.parameters[i].name); } Append(")\n"); EnterScope(); - PushScope(); { for (auto& statement : node.statements) statement->Visit(*this); } - PopScope(); LeaveScope(); } void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node) { - RegisterStruct(node.description); + bool isStd140 = false; + for (const auto& [attributeType, attributeParam] : node.attributes) + { + if (attributeType == ShaderAst::AttributeType::Layout && std::get(attributeParam) == "std140") + { + isStd140 = true; + break; + } + } + + assert(node.structIndex); + RegisterStruct(*node.structIndex, isStd140, node.description); Append("struct "); AppendLine(node.description.name); @@ -866,7 +890,8 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node) { - RegisterVariable(node.varName, node.varType); + assert(node.varIndex); + RegisterVariable(*node.varIndex, node.varName); Append(node.varType); Append(" "); @@ -891,11 +916,6 @@ namespace Nz AppendLine(";"); } - void GlslWriter::Visit(ShaderAst::IdentifierExpression& node) - { - Append(node.identifier); - } - void GlslWriter::Visit(ShaderAst::IntrinsicExpression& node) { switch (node.intrinsic) @@ -926,8 +946,6 @@ namespace Nz void GlslWriter::Visit(ShaderAst::MultiStatement& node) { - PushScope(); - bool first = true; for (const ShaderAst::StatementPtr& statement : node.statements) { @@ -938,8 +956,6 @@ namespace Nz first = false; } - - PopScope(); } void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/) @@ -987,6 +1003,12 @@ namespace Nz } } + void GlslWriter::Visit(ShaderAst::VariableExpression& node) + { + const std::string& varName = m_currentState->variableNames[node.variableId]; + Append(varName); + } + bool GlslWriter::HasExplicitBinding(ShaderAst::StatementPtr& shader) { /*for (const auto& uniform : shader.GetUniforms()) diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index 8f998341e..98ecb65d8 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -567,7 +567,6 @@ namespace Nz::ShaderAst for (const auto& extVar : node.externalVars) { bool hasBinding = false; - bool hasLayout = false; for (const auto& [attributeType, arg] : extVar.attributes) { switch (attributeType) @@ -588,21 +587,6 @@ namespace Nz::ShaderAst break; } - case AttributeType::Layout: - { - if (hasLayout) - throw AstError{ "attribute layout must be present once" }; - - if (!std::holds_alternative(arg)) - throw AstError{ "attribute layout requires a string parameter" }; - - if (std::get(arg) != "std140") - throw AstError{ "unknown layout type" }; - - hasLayout = true; - break; - } - default: throw AstError{ "unhandled attribute for external variable" }; } diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 8ef74aef7..82a1d360c 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -238,7 +238,7 @@ namespace Nz Int32(memberIndex), m_constantCache.Register(*m_constantCache.BuildType(member.type)), varId - }); + }); } memberIndex++;