diff --git a/tests/Engine/Shader/AccessMemberTest.cpp b/tests/Engine/Shader/AccessMemberTest.cpp index c29e985df..f53823db1 100644 --- a/tests/Engine/Shader/AccessMemberTest.cpp +++ b/tests/Engine/Shader/AccessMemberTest.cpp @@ -6,7 +6,7 @@ #include #include -SCENARIO("Shader generation", "[Shader]") +TEST_CASE("structure member access", "[Shader]") { SECTION("Nested member loading") { @@ -42,14 +42,14 @@ external multiStatement.statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl))); - ExpectingGLSL(*shader, R"( + ExpectGLSL(*shader, R"( void main() { float result = ubo.s.field.z; } )"); - ExpectingNZSL(*shader, R"( + ExpectNZSL(*shader, R"( [entry(vert)] fn main() { @@ -57,7 +57,7 @@ fn main() } )"); - ExpectingSpirV(*shader, R"( + ExpectSpirV(*shader, R"( OpFunction OpLabel OpVariable @@ -79,14 +79,14 @@ OpFunctionEnd)"); multiStatement.statements.push_back(Nz::ShaderBuilder::DeclareFunction(Nz::ShaderStageType::Vertex, "main", std::move(varDecl))); - ExpectingGLSL(*shader, R"( + ExpectGLSL(*shader, R"( void main() { float result = ubo.s.field.z; } )"); - ExpectingNZSL(*shader, R"( + ExpectNZSL(*shader, R"( [entry(vert)] fn main() { @@ -94,7 +94,7 @@ fn main() } )"); - ExpectingSpirV(*shader, R"( + ExpectSpirV(*shader, R"( OpFunction OpLabel OpVariable diff --git a/tests/Engine/Shader/Branch.cpp b/tests/Engine/Shader/Branch.cpp new file mode 100644 index 000000000..0f28f3bac --- /dev/null +++ b/tests/Engine/Shader/Branch.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include +#include +#include + +TEST_CASE("branching", "[Shader]") +{ + std::string_view nzslSource = R"( +struct inputStruct +{ + value: f32 +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + let value: f32; + if (data.value > 0.0) + value = 1.0; + else + value = 0.0; +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + float value; + if (data.value > (0.000000)) + { + value = 1.000000; + } + else + { + value = 0.000000; + } + +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let value: f32; + if (data.value > (0.000000)) + { + value = 1.000000; + } + else + { + value = 0.000000; + } + +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpAccessChain +OpLoad +OpFOrdGreaterThanEqual +OpSelectionMerge +OpBranchConditional +OpLabel +OpStore +OpBranch +OpLabel +OpStore +OpBranch +OpLabel +OpReturn +OpFunctionEnd)"); +} diff --git a/tests/Engine/Shader/Loops.cpp b/tests/Engine/Shader/Loops.cpp new file mode 100644 index 000000000..b40aa58d3 --- /dev/null +++ b/tests/Engine/Shader/Loops.cpp @@ -0,0 +1,90 @@ +#include +#include +#include +#include +#include +#include +#include + +TEST_CASE("loops", "[Shader]") +{ + std::string_view nzslSource = R"( +struct inputStruct +{ + value: f32 +} + +external +{ + [set(0), binding(0)] data: uniform +} + +[entry(frag)] +fn main() +{ + let value = 0.0; + let i = 0; + while (i < 10) + { + value += 0.1; + i += 1; + } +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + float value = 0.000000; + int i = 0; + while (i < (10)) + { + value += 0.100000; + i += 1; + } + +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let value: f32 = 0.000000; + let i: i32 = 0; + while (i < (10)) + { + value += 0.100000; + i += 1; + } + +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpVariable +OpStore +OpStore +OpBranch +OpLabel +OpLoad +OpSLessThan +OpLoopMerge +OpBranchConditional +OpLabel +OpLoad +OpFAdd +OpStore +OpLoad +OpIAdd +OpStore +OpBranch +OpLabel +OpReturn +OpFunctionEnd)"); +} diff --git a/tests/Engine/Shader/ShaderUtils.cpp b/tests/Engine/Shader/ShaderUtils.cpp index 92b6fd4b8..4af6dab3b 100644 --- a/tests/Engine/Shader/ShaderUtils.cpp +++ b/tests/Engine/Shader/ShaderUtils.cpp @@ -120,14 +120,13 @@ namespace }; } -void ExpectingGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput) +void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput) { expectedOutput = Nz::Trim(expectedOutput); - Nz::GlslWriter writer; - SECTION("Generating GLSL") { + Nz::GlslWriter writer; std::string output = writer.Generate(shader); WHEN("Validating expected code") @@ -156,14 +155,13 @@ void ExpectingGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOu } } -void ExpectingNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput) +void ExpectNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput) { expectedOutput = Nz::Trim(expectedOutput); - Nz::LangWriter writer; - SECTION("Generating NZSL") { + Nz::LangWriter writer; std::string output = writer.Generate(shader); WHEN("Validating expected code") @@ -180,19 +178,19 @@ void ExpectingNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOu } } -void ExpectingSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput) +void ExpectSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput, bool outputParameter) { expectedOutput = Nz::Trim(expectedOutput); - Nz::SpirvWriter writer; - Nz::SpirvPrinter printer; - - Nz::SpirvPrinter::Settings settings; - settings.printHeader = false; - settings.printParameters = false; - SECTION("Generating SPIRV") { + Nz::SpirvWriter writer; + Nz::SpirvPrinter printer; + + Nz::SpirvPrinter::Settings settings; + settings.printHeader = false; + settings.printParameters = outputParameter; + auto spirv = writer.Generate(shader); std::string output = printer.Print(spirv.data(), spirv.size(), settings); diff --git a/tests/Engine/Shader/ShaderUtils.hpp b/tests/Engine/Shader/ShaderUtils.hpp index 312135a7b..023c5d8ad 100644 --- a/tests/Engine/Shader/ShaderUtils.hpp +++ b/tests/Engine/Shader/ShaderUtils.hpp @@ -6,8 +6,8 @@ #include #include -void ExpectingGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput); -void ExpectingNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput); -void ExpectingSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput); +void ExpectGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput); +void ExpectNZSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput); +void ExpectSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput, bool outputParameter = false); #endif diff --git a/tests/Engine/Shader/Swizzle.cpp b/tests/Engine/Shader/Swizzle.cpp new file mode 100644 index 000000000..7bfe97121 --- /dev/null +++ b/tests/Engine/Shader/Swizzle.cpp @@ -0,0 +1,285 @@ +#include +#include +#include +#include +#include +#include +#include + +TEST_CASE("swizzle", "[Shader]") +{ + SECTION("Simple swizzle") + { + WHEN("reading") + { + std::string_view nzslSource = R"( +[entry(frag)] +fn main() +{ + let vec = vec4(0.0, 1.0, 2.0, 3.0); + let value = vec.xyz; +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + vec4 vec = vec4(0.000000, 1.000000, 2.000000, 3.000000); + vec3 value = vec.xyz; +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let vec: vec4 = vec4(0.000000, 1.000000, 2.000000, 3.000000); + let value: vec3 = vec.xyz; +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpVariable +OpCompositeConstruct +OpStore +OpLoad +OpVectorShuffle +OpStore +OpReturn +OpFunctionEnd)"); + } + + WHEN("writing") + { + std::string_view nzslSource = R"( +[entry(frag)] +fn main() +{ + let vec = vec4(0.0, 0.0, 0.0, 0.0); + vec.yzw = vec3(1.0, 2.0, 3.0); +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + vec4 vec = vec4(0.000000, 0.000000, 0.000000, 0.000000); + vec.yzw = vec3(1.000000, 2.000000, 3.000000); +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let vec: vec4 = vec4(0.000000, 0.000000, 0.000000, 0.000000); + vec.yzw = vec3(1.000000, 2.000000, 3.000000); +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpCompositeConstruct +OpStore +OpCompositeConstruct +OpLoad +OpVectorShuffle +OpStore +OpReturn +OpFunctionEnd)"); + } + } + + SECTION("Scalar swizzle") + { + GIVEN("a variable") + { + std::string_view nzslSource = R"( +[entry(frag)] +fn main() +{ + let value = 42.0; + let vec = value.xxx; +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + float value = 42.000000; + vec3 vec = vec3(value, value, value); +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let value: f32 = 42.000000; + let vec: vec3 = value.xxx; +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpVariable +OpStore +OpLoad +OpCompositeConstruct +OpStore +OpReturn +OpFunctionEnd)"); + } + + GIVEN("a function value") + { + std::string_view nzslSource = R"( +[entry(frag)] +fn main() +{ + let vec = max(2.0, 1.0).xxx; +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + float cachedResult = max(2.000000, 1.000000); + vec3 vec = vec3(cachedResult, cachedResult, cachedResult); +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let vec: vec3 = (max(2.000000, 1.000000)).xxx; +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpExtInst +OpCompositeConstruct +OpStore +OpReturn +OpFunctionEnd)"); + } + } + + SECTION("Complex swizzle") + { + WHEN("reading") + { + std::string_view nzslSource = R"( +[entry(frag)] +fn main() +{ + let vec = vec4(0.0, 1.0, 2.0, 3.0); + let value = vec.xyz.yz.y.x.xxxx; +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + vec4 vec = vec4(0.000000, 1.000000, 2.000000, 3.000000); + vec4 value = vec4(vec.xyz.yz.y, vec.xyz.yz.y, vec.xyz.yz.y, vec.xyz.yz.y); +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let vec: vec4 = vec4(0.000000, 1.000000, 2.000000, 3.000000); + let value: vec4 = vec.xyz.yz.y.x.xxxx; +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpVariable +OpCompositeConstruct +OpStore +OpLoad +OpVectorShuffle +OpVectorShuffle +OpCompositeExtract +OpCompositeConstruct +OpStore +OpReturn +OpFunctionEnd)"); + } + + WHEN("writing") + { + std::string_view nzslSource = R"( +[entry(frag)] +fn main() +{ + let vec = vec4(0.0, 1.0, 2.0, 3.0); + vec.wyxz.bra.ts.x = 0.0; + vec.zyxw.ar.xy.yx = vec2(1.0, 0.0); +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + ExpectGLSL(*shader, R"( +void main() +{ + vec4 vec = vec4(0.000000, 1.000000, 2.000000, 3.000000); + vec.wyxz.zxw.yx.x = 0.000000; + vec.zyxw.wx.xy.yx = vec2(1.000000, 0.000000); +} +)"); + + ExpectNZSL(*shader, R"( +[entry(frag)] +fn main() +{ + let vec: vec4 = vec4(0.000000, 1.000000, 2.000000, 3.000000); + vec.wyxz.zxw.yx.x = 0.000000; + vec.zyxw.wx.xy.yx = vec2(1.000000, 0.000000); +} +)"); + + ExpectSpirV(*shader, R"( +OpFunction +OpLabel +OpVariable +OpCompositeConstruct +OpStore +OpAccessChain +OpStore +OpCompositeConstruct +OpLoad +OpVectorShuffle +OpStore +OpReturn +OpFunctionEnd)"); + } + } +}