diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index ffc19987d..99e68d9f0 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -29,7 +29,8 @@ namespace Nz GlslWriter(GlslWriter&&) = delete; ~GlslWriter() = default; - std::string Generate(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader, const States& conditions = {}); + std::string Generate(ShaderAst::StatementPtr& shader, const States& conditions = {}); + std::string Generate(std::optional shaderStage, ShaderAst::StatementPtr& shader, const States& conditions = {}); void SetEnv(Environment environment); diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index e80bd1c3a..a509245e4 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -43,19 +43,27 @@ namespace Nz ShaderAst::DeclareFunctionStatement* func = static_cast(clone.get()); // Remove function if it's an entry point of another type than the one selected - if (node.entryStage) + if (selectedStage) { - ShaderStageType stage = *node.entryStage; - if (stage != selectedStage) - return ShaderBuilder::NoOp(); + if (node.entryStage) + { + ShaderStageType stage = *node.entryStage; + if (stage != *selectedStage) + return ShaderBuilder::NoOp(); + entryPoint = func; + } + } + else + { + assert(!entryPoint); entryPoint = func; } return clone; } - ShaderStageType selectedStage; + std::optional selectedStage; ShaderAst::DeclareFunctionStatement* entryPoint = nullptr; }; @@ -79,7 +87,7 @@ namespace Nz std::string targetName; }; - ShaderStageType stage; + std::optional stage; const States* states = nullptr; ShaderAst::DeclareFunctionStatement* entryFunc = nullptr; std::stringstream stream; @@ -97,7 +105,12 @@ namespace Nz { } - std::string GlslWriter::Generate(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader, const States& conditions) + std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions) + { + return Generate(std::nullopt, shader, conditions); + } + + std::string GlslWriter::Generate(std::optional shaderStage, ShaderAst::StatementPtr& shader, const States& conditions) { State state; state.stage = shaderStage; @@ -394,7 +407,7 @@ namespace Nz assert(it != s_builtinMapping.end()); const Builtin& builtin = it->second; - if (!builtin.stageFlags.Test(m_currentState->stage)) + if (m_currentState->stage && !builtin.stageFlags.Test(*m_currentState->stage)) continue; //< This builtin is not active in this stage, skip it fields.push_back({ @@ -852,8 +865,6 @@ namespace Nz { bool isOutputPosition = (m_currentState->stage == ShaderStageType::Vertex && m_environment.flipYPosition && targetName == "gl_Position"); - AppendLine(); - Append(targetName, " = ", outputStructVarName, ".", name); if (isOutputPosition) Append(" * vec4(1.0, ", s_flipYUniformName, ", 1.0, 1.0)"); @@ -861,7 +872,6 @@ namespace Nz AppendLine(";"); } - AppendLine(); Append("return;"); //< TODO: Don't return if it's the last statement of the function } else diff --git a/tests/Engine/Shader/AccessMember.cpp b/tests/Engine/Shader/AccessMember.cpp index 7ed92bd9c..104c76a9d 100644 --- a/tests/Engine/Shader/AccessMember.cpp +++ b/tests/Engine/Shader/AccessMember.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -18,7 +17,7 @@ std::string_view Trim(std::string_view str) return str; } -void ExpectingGLSL(const Nz::ShaderAst& shader, std::string_view expectedOutput) +void ExpectingGLSL(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput) { Nz::GlslWriter writer; @@ -30,7 +29,7 @@ void ExpectingGLSL(const Nz::ShaderAst& shader, std::string_view expectedOutput) REQUIRE(subset == expectedOutput); } -void ExpectingSpirV(const Nz::ShaderAst& shader, std::string_view expectedOutput) +void ExpectingSpirV(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput) { Nz::SpirvWriter writer; auto spirv = writer.Generate(shader); @@ -53,42 +52,53 @@ SCENARIO("Shader generation", "[Shader]") { SECTION("Nested member loading") { - Nz::ShaderAst baseShader(Nz::ShaderStageType::Vertex); - baseShader.AddStruct("innerStruct", { - { - "field", - Nz::ShaderNodes::BasicType::Float3 - } - }); - - baseShader.AddStruct("outerStruct", { - { - "s", - "innerStruct" - } - }); + std::vector statements; - baseShader.AddUniform("ubo", "outerStruct"); - baseShader.AddOutput("result", Nz::ShaderNodes::BasicType::Float1); + Nz::ShaderAst::StructDescription innerStructDesc; + { + innerStructDesc.name = "innerStruct"; + auto& member = innerStructDesc.members.emplace_back(); + member.name = "field"; + member.type = Nz::ShaderAst::VectorType{ 3, Nz::ShaderAst::PrimitiveType::Float32 }; + } + statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(innerStructDesc))); + + Nz::ShaderAst::StructDescription outerStruct; + { + outerStruct.name = "outerStruct"; + auto& member = outerStruct.members.emplace_back(); + member.name = "s"; + member.type = Nz::ShaderAst::IdentifierType{ "innerStruct" }; + } + statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(outerStruct))); + + auto external = std::make_unique(); + external->externalVars.push_back({ + std::nullopt, + "ubo", + Nz::ShaderAst::IdentifierType{ "outerStruct" } + }); + statements.push_back(std::move(external)); SECTION("Nested AccessMember") { - Nz::ShaderAst shader = baseShader; + auto ubo = Nz::ShaderBuilder::Identifier("ubo"); + auto firstAccess = Nz::ShaderBuilder::AccessMember(std::move(ubo), { "s" }); + auto secondAccess = Nz::ShaderBuilder::AccessMember(std::move(firstAccess), { "field" }); - auto uniform = Nz::ShaderBuilder::Uniform("ubo", "outerStruct"); - auto output = Nz::ShaderBuilder::Output("result", Nz::ShaderNodes::BasicType::Float1); + auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(secondAccess), { Nz::ShaderAst::SwizzleComponent::Third }); + auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::PrimitiveType::Float32, std::move(swizzle)); - auto access = Nz::ShaderBuilder::Swizzle(Nz::ShaderBuilder::AccessMember(Nz::ShaderBuilder::AccessMember(Nz::ShaderBuilder::Identifier(uniform), 0, "innerStruct"), 0, Nz::ShaderNodes::BasicType::Float3), Nz::ShaderNodes::SwizzleComponent::Third); - auto assign = Nz::ShaderBuilder::Assign(Nz::ShaderBuilder::Identifier(output), access); + statements.push_back(Nz::ShaderBuilder::DeclareFunction("main", std::move(varDecl))); - shader.AddFunction("main", Nz::ShaderBuilder::ExprStatement(assign)); + Nz::ShaderAst::StatementPtr shader = Nz::ShaderBuilder::MultiStatement(std::move(statements)); SECTION("Generating GLSL") { ExpectingGLSL(shader, R"( void main() { - result = ubo.s.field.z; + float result = ubo.s.field.z; } )"); } @@ -97,6 +107,7 @@ void main() ExpectingSpirV(shader, R"( OpFunction OpLabel +OpVariable OpAccessChain OpAccessChain OpLoad @@ -109,22 +120,22 @@ OpFunctionEnd)"); SECTION("AccessMember with multiples fields") { - Nz::ShaderAst shader = baseShader; + auto ubo = Nz::ShaderBuilder::Identifier("ubo"); + auto access = Nz::ShaderBuilder::AccessMember(std::move(ubo), { "s", "field" }); - auto uniform = Nz::ShaderBuilder::Uniform("ubo", "outerStruct"); - auto output = Nz::ShaderBuilder::Output("result", Nz::ShaderNodes::BasicType::Float1); + auto swizzle = Nz::ShaderBuilder::Swizzle(std::move(access), { Nz::ShaderAst::SwizzleComponent::Third }); + auto varDecl = Nz::ShaderBuilder::DeclareVariable("result", Nz::ShaderAst::PrimitiveType::Float32, std::move(swizzle)); - auto access = Nz::ShaderBuilder::Swizzle(Nz::ShaderBuilder::AccessMember(Nz::ShaderBuilder::Identifier(uniform), std::vector{ 0, 0 }, Nz::ShaderNodes::BasicType::Float3), Nz::ShaderNodes::SwizzleComponent::Third); - auto assign = Nz::ShaderBuilder::Assign(Nz::ShaderBuilder::Identifier(output), access); + statements.push_back(Nz::ShaderBuilder::DeclareFunction("main", std::move(varDecl))); - shader.AddFunction("main", Nz::ShaderBuilder::ExprStatement(assign)); + Nz::ShaderAst::StatementPtr shader = Nz::ShaderBuilder::MultiStatement(std::move(statements)); SECTION("Generating GLSL") { ExpectingGLSL(shader, R"( void main() { - result = ubo.s.field.z; + float result = ubo.s.field.z; } )"); } @@ -133,6 +144,7 @@ void main() ExpectingSpirV(shader, R"( OpFunction OpLabel +OpVariable OpAccessChain OpLoad OpCompositeExtract