Fix shader generation unit tests

This commit is contained in:
Jérôme Leclercq 2021-06-16 15:46:14 +02:00
parent 298beaedc0
commit dfa46ebaa5
6 changed files with 41 additions and 50 deletions

View File

@ -31,8 +31,8 @@ namespace Nz
GlslWriter(GlslWriter&&) = delete; GlslWriter(GlslWriter&&) = delete;
~GlslWriter() = default; ~GlslWriter() = default;
inline std::string Generate(ShaderAst::Statement& shader, const BindingMapping& bindingMapping, const States& states = {}); inline std::string Generate(ShaderAst::Statement& shader, const BindingMapping& bindingMapping = {}, const States& states = {});
std::string Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::Statement& shader, const BindingMapping& bindingMapping, const States& states = {}); std::string Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::Statement& shader, const BindingMapping& bindingMapping = {}, const States& states = {});
void SetEnv(Environment environment); void SetEnv(Environment environment);

View File

@ -707,11 +707,8 @@ namespace Nz::ShaderAst
if (!extVar.bindingIndex) if (!extVar.bindingIndex)
throw AstError{ "external variable " + extVar.name + " requires a binding index" }; throw AstError{ "external variable " + extVar.name + " requires a binding index" };
if (!extVar.bindingSet)
throw AstError{ "external variable " + extVar.name + " requires a binding set" };
UInt64 bindingIndex = *extVar.bindingIndex; UInt64 bindingIndex = *extVar.bindingIndex;
UInt64 bindingSet = *extVar.bindingSet; UInt64 bindingSet = extVar.bindingSet.value_or(0);
UInt64 bindingKey = bindingSet << 32 | bindingIndex; UInt64 bindingKey = bindingSet << 32 | bindingIndex;
if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end()) if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" }; throw AstError{ "Binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" };

View File

@ -52,22 +52,6 @@ namespace Nz
node.statement->Visit(*this); node.statement->Visit(*this);
} }
void Visit(ShaderAst::DeclareExternalStatement& node) override
{
AstRecursiveVisitor::Visit(node);
for (auto& extVar : node.externalVars)
{
assert(extVar.bindingIndex);
assert(extVar.bindingSet);
UInt64 set = *extVar.bindingSet;
UInt64 binding = *extVar.bindingIndex;
bindings.insert(set << 32 | binding);
}
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override void Visit(ShaderAst::DeclareFunctionStatement& node) override
{ {
// Dismiss function if it's an entry point of another type than the one selected // Dismiss function if it's an entry point of another type than the one selected
@ -113,7 +97,6 @@ namespace Nz
FunctionData* currentFunction = nullptr; FunctionData* currentFunction = nullptr;
std::set<UInt64 /*set | binding*/> bindings;
std::optional<ShaderStageType> selectedStage; std::optional<ShaderStageType> selectedStage;
std::unordered_map<std::size_t, FunctionData> functions; std::unordered_map<std::size_t, FunctionData> functions;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr; ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
@ -867,21 +850,32 @@ namespace Nz
isStd140 = structInfo.layout == StructLayout::Std140; isStd140 = structInfo.layout == StructLayout::Std140;
} }
if (!m_currentState->bindingMapping.empty() || isStd140)
Append("layout(");
if (!m_currentState->bindingMapping.empty())
{
assert(externalVar.bindingIndex); assert(externalVar.bindingIndex);
assert(externalVar.bindingSet);
UInt64 bindingIndex = *externalVar.bindingIndex; UInt64 bindingIndex = *externalVar.bindingIndex;
UInt64 bindingSet = *externalVar.bindingSet; UInt64 bindingSet = externalVar.bindingSet.value_or(0);
auto bindingIt = m_currentState->bindingMapping.find(bindingSet << 32 | bindingIndex); auto bindingIt = m_currentState->bindingMapping.find(bindingSet << 32 | bindingIndex);
if (bindingIt == m_currentState->bindingMapping.end()) if (bindingIt == m_currentState->bindingMapping.end())
throw std::runtime_error("no binding found for (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ")"); throw std::runtime_error("no binding found for (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ")");
Append("layout(binding = ", bindingIt->second); Append("binding = ", bindingIt->second);
if (isStd140) if (isStd140)
Append(", std140"); Append(", ");
}
Append(") uniform "); if (isStd140)
Append("std140");
if (!m_currentState->bindingMapping.empty() || isStd140)
Append(") ");
Append("uniform ");
if (IsUniformType(externalVar.type)) if (IsUniformType(externalVar.type))
{ {

View File

@ -482,8 +482,8 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Colon); Expect(Advance(), TokenType::Colon);
extVar.type = ParseType(); extVar.type = ParseType();
if (!extVar.bindingSet) if (!extVar.bindingSet && blockSetIndex)
extVar.bindingSet = blockSetIndex.value_or(0); extVar.bindingSet = *blockSetIndex;
RegisterVariable(extVar.name); RegisterVariable(extVar.name);
} }

View File

@ -46,8 +46,8 @@ namespace Nz
public: public:
struct UniformVar struct UniformVar
{ {
std::optional<UInt32> bindingIndex; UInt32 bindingIndex;
std::optional<UInt32> descriptorSet; UInt32 descriptorSet;
UInt32 pointerId; UInt32 pointerId;
}; };
@ -123,10 +123,12 @@ namespace Nz
variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform; variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass); variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
assert(extVar.bindingIndex);
UniformVar& uniformVar = extVars[varIndex++]; UniformVar& uniformVar = extVars[varIndex++];
uniformVar.pointerId = m_constantCache.Register(variable); uniformVar.pointerId = m_constantCache.Register(variable);
uniformVar.bindingIndex = extVar.bindingIndex; uniformVar.bindingIndex = *extVar.bindingIndex;
uniformVar.descriptorSet = extVar.bindingSet; uniformVar.descriptorSet = extVar.bindingSet.value_or(0);
} }
} }
@ -491,11 +493,8 @@ namespace Nz
for (auto&& [varIndex, extVar] : preVisitor.extVars) for (auto&& [varIndex, extVar] : preVisitor.extVars)
{ {
if (extVar.bindingIndex) state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
{ state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, *extVar.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, *extVar.descriptorSet);
}
} }
for (auto&& [varId, builtin] : preVisitor.builtinDecorations) for (auto&& [varId, builtin] : preVisitor.builtinDecorations)

View File

@ -7,7 +7,7 @@
#include <catch2/catch.hpp> #include <catch2/catch.hpp>
#include <cctype> #include <cctype>
void ExpectingGLSL(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput) void ExpectingGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput)
{ {
Nz::GlslWriter writer; Nz::GlslWriter writer;
@ -19,7 +19,7 @@ void ExpectingGLSL(Nz::ShaderAst::StatementPtr& shader, std::string_view expecte
REQUIRE(subset == expectedOutput); REQUIRE(subset == expectedOutput);
} }
void ExpectingSpirV(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput) void ExpectingSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput)
{ {
Nz::SpirvWriter writer; Nz::SpirvWriter writer;
auto spirv = writer.Generate(shader); auto spirv = writer.Generate(shader);
@ -64,6 +64,7 @@ SCENARIO("Shader generation", "[Shader]")
auto external = std::make_unique<Nz::ShaderAst::DeclareExternalStatement>(); auto external = std::make_unique<Nz::ShaderAst::DeclareExternalStatement>();
external->externalVars.push_back({ external->externalVars.push_back({
0,
std::nullopt, std::nullopt,
"ubo", "ubo",
Nz::ShaderAst::UniformType{ Nz::ShaderAst::IdentifierType{ "outerStruct" } } Nz::ShaderAst::UniformType{ Nz::ShaderAst::IdentifierType{ "outerStruct" } }
@ -85,7 +86,7 @@ SCENARIO("Shader generation", "[Shader]")
SECTION("Generating GLSL") SECTION("Generating GLSL")
{ {
ExpectingGLSL(shader, R"( ExpectingGLSL(*shader, R"(
void main() void main()
{ {
float result = ubo.s.field.z; float result = ubo.s.field.z;
@ -94,7 +95,7 @@ void main()
} }
SECTION("Generating Spir-V") SECTION("Generating Spir-V")
{ {
ExpectingSpirV(shader, R"( ExpectingSpirV(*shader, R"(
OpFunction OpFunction
OpLabel OpLabel
OpVariable OpVariable
@ -121,7 +122,7 @@ OpFunctionEnd)");
SECTION("Generating GLSL") SECTION("Generating GLSL")
{ {
ExpectingGLSL(shader, R"( ExpectingGLSL(*shader, R"(
void main() void main()
{ {
float result = ubo.s.field.z; float result = ubo.s.field.z;
@ -130,7 +131,7 @@ void main()
} }
SECTION("Generating Spir-V") SECTION("Generating Spir-V")
{ {
ExpectingSpirV(shader, R"( ExpectingSpirV(*shader, R"(
OpFunction OpFunction
OpLabel OpLabel
OpVariable OpVariable