diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index fe6ab7fb9..1346e15c0 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -376,7 +376,8 @@ namespace Nz SpirvBlock mergeBlock(m_writer); - previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice + if (!previousContentBlock.IsTerminated()) + previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); m_functionBlocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None); @@ -397,7 +398,8 @@ namespace Nz statement.statement->Visit(*this); - previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice + if (!previousContentBlock.IsTerminated()) + previousContentBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); } if (node.elseStatement) @@ -408,7 +410,8 @@ namespace Nz node.elseStatement->Visit(*this); - elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice + if (!elseBlock.IsTerminated()) + elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); m_functionBlocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId()); m_functionBlocks.emplace_back(std::move(previousContentBlock)); diff --git a/tests/Engine/Shader/Branch.cpp b/tests/Engine/Shader/Branch.cpp index 0f28f3bac..269fb3ba9 100644 --- a/tests/Engine/Shader/Branch.cpp +++ b/tests/Engine/Shader/Branch.cpp @@ -8,7 +8,9 @@ TEST_CASE("branching", "[Shader]") { - std::string_view nzslSource = R"( + WHEN("using a simple branch") + { + std::string_view nzslSource = R"( struct inputStruct { value: f32 @@ -30,9 +32,9 @@ fn main() } )"; - Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); - ExpectGLSL(*shader, R"( + ExpectGLSL(*shader, R"( void main() { float value; @@ -48,7 +50,7 @@ void main() } )"); - ExpectNZSL(*shader, R"( + ExpectNZSL(*shader, R"( [entry(frag)] fn main() { @@ -65,7 +67,7 @@ fn main() } )"); - ExpectSpirV(*shader, R"( + ExpectSpirV(*shader, R"( OpFunction OpLabel OpVariable @@ -83,4 +85,66 @@ OpBranch OpLabel OpReturn OpFunctionEnd)"); + } + + WHEN("discarding in a branch") + { + std::string_view nzslSource = R"( +struct inputStruct +{ + value: f32 +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + if (data.value > 0.0) + discard; +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + if (data.value > (0.000000)) + { + discard; + } + +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + if (data.value > (0.000000)) + { + discard; + } + +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpAccessChain +OpLoad +OpFOrdGreaterThanEqual +OpSelectionMerge +OpBranchConditional +OpLabel +OpKill +OpLabel +OpReturn +OpFunctionEnd)"); + } } diff --git a/tests/Engine/Shader/ShaderUtils.cpp b/tests/Engine/Shader/ShaderUtils.cpp index 4af6dab3b..8d170c8b9 100644 --- a/tests/Engine/Shader/ShaderUtils.cpp +++ b/tests/Engine/Shader/ShaderUtils.cpp @@ -137,8 +137,8 @@ void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutpu WHEN("Validating full GLSL code (using glslang)") { - glslang::TShader shader(EShLangVertex); - shader.setEnvInput(glslang::EShSourceGlsl, EShLangVertex, glslang::EShClientOpenGL, 300); + glslang::TShader shader(EShLangFragment); + shader.setEnvInput(glslang::EShSourceGlsl, EShLangFragment, glslang::EShClientOpenGL, 300); shader.setEnvClient(glslang::EShClientOpenGL, glslang::EShTargetOpenGL_450); shader.setEnvTarget(glslang::EShTargetNone, static_cast(0)); shader.setEntryPoint("main");