From 2f64e493debba881b45d1636999476c9ee921326 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Tue, 28 Dec 2021 11:49:53 +0100 Subject: [PATCH] UnitTests/Shader: Add sanitization and optimizations unit tests --- tests/Engine/Shader/Optimizations.cpp | 145 ++++++++++++++++++++++++++ tests/Engine/Shader/Sanitizations.cpp | 81 ++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 tests/Engine/Shader/Optimizations.cpp create mode 100644 tests/Engine/Shader/Sanitizations.cpp diff --git a/tests/Engine/Shader/Optimizations.cpp b/tests/Engine/Shader/Optimizations.cpp new file mode 100644 index 000000000..73b19192d --- /dev/null +++ b/tests/Engine/Shader/Optimizations.cpp @@ -0,0 +1,145 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void ExpectOptimization(std::string_view sourceCode, std::string_view expectedOptimizedResult) +{ + Nz::ShaderAst::StatementPtr shader; + REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); + REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader)); + REQUIRE_NOTHROW(shader = Nz::ShaderAst::Optimize(*shader)); + + ExpectNZSL(*shader, expectedOptimizedResult); +} + +TEST_CASE("optimizations", "[Shader]") +{ + WHEN("propaging constants") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + let output = 8.0 * (7.0 + 5.0) * 2.0 / 4.0 - 6.0; +} +)", R"( +[entry(frag)] +fn main() +{ + let output: f32 = 42.000000; +)"); + } + + WHEN("propaging vector constants") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + let output = vec4(8.0, 2.0, -7.0, 0.0) * (7.0 + 5.0) * 2.0 / 4.0; +} +)", R"( +[entry(frag)] +fn main() +{ + let output: vec4 = vec4(48.000000, 12.000000, -42.000000, 0.000000); +)"); + } + + WHEN("eliminating simple branch") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + if (5 + 3 < 2) + discard; +} +)", R"( +[entry(frag)] +fn main() +{ + +} +)"); + } + + WHEN("eliminating multiple branches") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + let output = 0.0; + if (5 <= 3) + output = 5.0; + else if (4 <= 3) + output = 4.0; + else if (3 <= 3) + output = 3.0; + else if (2 <= 3) + output = 2.0; + else if (1 <= 3) + output = 1.0; + else + output = 0.0; +} +)", R"( +[entry(frag)] +fn main() +{ + let output: f32 = 0.000000; + output = 3.000000; +} +)"); + } + + + WHEN("eliminating multiple splitted branches") + { + ExpectOptimization(R"( +[entry(frag)] +fn main() +{ + let output = 0.0; + if (5 <= 3) + output = 5.0; + else + { + if (4 <= 3) + output = 4.0; + else + { + if (3 <= 3) + output = 3.0; + else + { + if (2 <= 3) + output = 2.0; + else + { + if (1 <= 3) + output = 1.0; + else + output = 0.0; + } + } + } + } +} +)", R"( +[entry(frag)] +fn main() +{ + let output: f32 = 0.000000; + output = 3.000000; +} +)"); + } +} diff --git a/tests/Engine/Shader/Sanitizations.cpp b/tests/Engine/Shader/Sanitizations.cpp new file mode 100644 index 000000000..9a376c0a8 --- /dev/null +++ b/tests/Engine/Shader/Sanitizations.cpp @@ -0,0 +1,81 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +TEST_CASE("sanitizing", "[Shader]") +{ + WHEN("splitting branches") + { + std::string_view nzslSource = R"( +struct inputStruct +{ + value: f32 +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + let value: f32; + if (data.value > 3.0) + value = 3.0; + else if (data.value > 2.0) + value = 2.0; + else if (data.value > 1.0) + value = 1.0; + else + value = 0.0; +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + Nz::ShaderAst::SanitizeVisitor::Options options; + options.splitMultipleBranches = true; + + REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader, options)); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let value: f32; + if (data.value > (3.000000)) + { + value = 3.000000; + } + else + { + if (data.value > (2.000000)) + { + value = 2.000000; + } + else + { + if (data.value > (1.000000)) + { + value = 1.000000; + } + else + { + value = 0.000000; + } + + } + + } + +} +)"); + + } +}