#include #include #include #include #include #include #include #include #include #include void PropagateConstantAndExpect(std::string_view sourceCode, std::string_view expectedOptimizedResult) { Nz::ShaderAst::StatementPtr shader; REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader)); REQUIRE_NOTHROW(shader = Nz::ShaderAst::PropagateConstants(*shader)); ExpectNZSL(*shader, expectedOptimizedResult); } void EliminateUnusedAndExpect(std::string_view sourceCode, std::string_view expectedOptimizedResult) { Nz::ShaderAst::StatementPtr shader; REQUIRE_NOTHROW(shader = Nz::ShaderLang::Parse(sourceCode)); REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader)); REQUIRE_NOTHROW(shader = Nz::ShaderAst::EliminateUnusedPass(*shader)); ExpectNZSL(*shader, expectedOptimizedResult); } TEST_CASE("optimizations", "[Shader]") { WHEN("propagating constants") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { let output = 8.0 * (7.0 + 5.0) * 2.0 / 4.0 - 6.0; } )", R"( [entry(frag)] fn main() { let output: f32 = 42.000000; } )"); } WHEN("propagating vector constants") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { let output = vec4[f32](8.0, 2.0, -7.0, 0.0) * (7.0 + 5.0) * 2.0 / 4.0; } )", R"( [entry(frag)] fn main() { let output: vec4[f32] = vec4[f32](48.000000, 12.000000, -42.000000, 0.000000); )"); } WHEN("eliminating simple branch") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { if (5 + 3 < 2) discard; } )", R"( [entry(frag)] fn main() { } )"); } WHEN("eliminating multiple branches") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { let output = 0.0; if (5 <= 3) output = 5.0; else if (4 <= 3) output = 4.0; else if (3 <= 3) output = 3.0; else if (2 <= 3) output = 2.0; else if (1 <= 3) output = 1.0; else output = 0.0; } )", R"( [entry(frag)] fn main() { let output: f32 = 0.000000; output = 3.000000; } )"); } WHEN("eliminating multiple split branches") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { let output = 0.0; if (5 <= 3) output = 5.0; else { if (4 <= 3) output = 4.0; else { if (3 <= 3) output = 3.0; else { if (2 <= 3) output = 2.0; else { if (1 <= 3) output = 1.0; else output = 0.0; } } } } } )", R"( [entry(frag)] fn main() { let output: f32 = 0.000000; output = 3.000000; } )"); } WHEN("optimizing out scalar swizzle") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { let value = vec3[f32](3.0, 0.0, 1.0).z; } )", R"( [entry(frag)] fn main() { let value: f32 = 1.000000; } )"); } WHEN("optimizing out scalar swizzle to vector") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { let value = (42.0).xxxx; } )", R"( [entry(frag)] fn main() { let value: vec4[f32] = vec4[f32](42.000000, 42.000000, 42.000000, 42.000000); } )"); } WHEN("optimizing out vector swizzle") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { let value = vec4[f32](3.0, 0.0, 1.0, 2.0).yzwx; } )", R"( [entry(frag)] fn main() { let value: vec4[f32] = vec4[f32](0.000000, 1.000000, 2.000000, 3.000000); } )"); } WHEN("optimizing out vector swizzle with repetition") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { let value = vec4[f32](3.0, 0.0, 1.0, 2.0).zzxx; } )", R"( [entry(frag)] fn main() { let value: vec4[f32] = vec4[f32](1.000000, 1.000000, 3.000000, 3.000000); } )"); } WHEN("optimizing out complex swizzle") { PropagateConstantAndExpect(R"( [entry(frag)] fn main() { let value = vec4[f32](0.0, 1.0, 2.0, 3.0).xyz.yz.y.x.xxxx; } )", R"( [entry(frag)] fn main() { let value: vec4[f32] = vec4[f32](2.000000, 2.000000, 2.000000, 2.000000); } )"); } WHEN("optimizing out complex swizzle on unknown value") { PropagateConstantAndExpect(R"( struct inputStruct { value: vec4[f32] } external { [set(0), binding(0)] data: uniform[inputStruct] } [entry(frag)] fn main() { let value = data.value.xyz.yz.y.x.xxxx; } )", R"( [entry(frag)] fn main() { let value: vec4[f32] = data.value.zzzz; } )"); } WHEN("eliminating unused code") { EliminateUnusedAndExpect(R"( struct inputStruct { value: vec4[f32] } struct notUsed { value: vec4[f32] } external { [set(0), binding(0)] unusedData: uniform[notUsed], [set(0), binding(1)] data: uniform[inputStruct] } fn unusedFunction() -> vec4[f32] { return unusedData.value; } struct Output { value: vec4[f32] } [entry(frag)] fn main() -> Output { let unusedvalue = unusedFunction(); let output: Output; output.value = data.value; return output; })", R"( struct inputStruct { value: vec4[f32] } external { [set(0), binding(1)] data: uniform[inputStruct] } struct Output { value: vec4[f32] } [entry(frag)] fn main() -> Output { let output: Output; output.value = data.value; return output; })"); } }