Fix shader generation unit tests
This commit is contained in:
parent
298beaedc0
commit
dfa46ebaa5
|
|
@ -31,8 +31,8 @@ namespace Nz
|
|||
GlslWriter(GlslWriter&&) = delete;
|
||||
~GlslWriter() = default;
|
||||
|
||||
inline std::string Generate(ShaderAst::Statement& shader, const BindingMapping& bindingMapping, const States& states = {});
|
||||
std::string Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::Statement& shader, const BindingMapping& bindingMapping, const States& states = {});
|
||||
inline std::string Generate(ShaderAst::Statement& shader, const BindingMapping& bindingMapping = {}, const States& states = {});
|
||||
std::string Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::Statement& shader, const BindingMapping& bindingMapping = {}, const States& states = {});
|
||||
|
||||
void SetEnv(Environment environment);
|
||||
|
||||
|
|
|
|||
|
|
@ -707,11 +707,8 @@ namespace Nz::ShaderAst
|
|||
if (!extVar.bindingIndex)
|
||||
throw AstError{ "external variable " + extVar.name + " requires a binding index" };
|
||||
|
||||
if (!extVar.bindingSet)
|
||||
throw AstError{ "external variable " + extVar.name + " requires a binding set" };
|
||||
|
||||
UInt64 bindingIndex = *extVar.bindingIndex;
|
||||
UInt64 bindingSet = *extVar.bindingSet;
|
||||
UInt64 bindingSet = extVar.bindingSet.value_or(0);
|
||||
UInt64 bindingKey = bindingSet << 32 | bindingIndex;
|
||||
if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end())
|
||||
throw AstError{ "Binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" };
|
||||
|
|
|
|||
|
|
@ -52,22 +52,6 @@ namespace Nz
|
|||
node.statement->Visit(*this);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareExternalStatement& node) override
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
for (auto& extVar : node.externalVars)
|
||||
{
|
||||
assert(extVar.bindingIndex);
|
||||
assert(extVar.bindingSet);
|
||||
|
||||
UInt64 set = *extVar.bindingSet;
|
||||
UInt64 binding = *extVar.bindingIndex;
|
||||
|
||||
bindings.insert(set << 32 | binding);
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
// Dismiss function if it's an entry point of another type than the one selected
|
||||
|
|
@ -113,7 +97,6 @@ namespace Nz
|
|||
|
||||
FunctionData* currentFunction = nullptr;
|
||||
|
||||
std::set<UInt64 /*set | binding*/> bindings;
|
||||
std::optional<ShaderStageType> selectedStage;
|
||||
std::unordered_map<std::size_t, FunctionData> functions;
|
||||
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
|
||||
|
|
@ -867,21 +850,32 @@ namespace Nz
|
|||
isStd140 = structInfo.layout == StructLayout::Std140;
|
||||
}
|
||||
|
||||
assert(externalVar.bindingIndex);
|
||||
assert(externalVar.bindingSet);
|
||||
if (!m_currentState->bindingMapping.empty() || isStd140)
|
||||
Append("layout(");
|
||||
|
||||
UInt64 bindingIndex = *externalVar.bindingIndex;
|
||||
UInt64 bindingSet = *externalVar.bindingSet;
|
||||
if (!m_currentState->bindingMapping.empty())
|
||||
{
|
||||
assert(externalVar.bindingIndex);
|
||||
|
||||
auto bindingIt = m_currentState->bindingMapping.find(bindingSet << 32 | bindingIndex);
|
||||
if (bindingIt == m_currentState->bindingMapping.end())
|
||||
throw std::runtime_error("no binding found for (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ")");
|
||||
UInt64 bindingIndex = *externalVar.bindingIndex;
|
||||
UInt64 bindingSet = externalVar.bindingSet.value_or(0);
|
||||
|
||||
auto bindingIt = m_currentState->bindingMapping.find(bindingSet << 32 | bindingIndex);
|
||||
if (bindingIt == m_currentState->bindingMapping.end())
|
||||
throw std::runtime_error("no binding found for (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ")");
|
||||
|
||||
Append("binding = ", bindingIt->second);
|
||||
if (isStd140)
|
||||
Append(", ");
|
||||
}
|
||||
|
||||
Append("layout(binding = ", bindingIt->second);
|
||||
if (isStd140)
|
||||
Append(", std140");
|
||||
Append("std140");
|
||||
|
||||
Append(") uniform ");
|
||||
if (!m_currentState->bindingMapping.empty() || isStd140)
|
||||
Append(") ");
|
||||
|
||||
Append("uniform ");
|
||||
|
||||
if (IsUniformType(externalVar.type))
|
||||
{
|
||||
|
|
|
|||
|
|
@ -482,8 +482,8 @@ namespace Nz::ShaderLang
|
|||
Expect(Advance(), TokenType::Colon);
|
||||
extVar.type = ParseType();
|
||||
|
||||
if (!extVar.bindingSet)
|
||||
extVar.bindingSet = blockSetIndex.value_or(0);
|
||||
if (!extVar.bindingSet && blockSetIndex)
|
||||
extVar.bindingSet = *blockSetIndex;
|
||||
|
||||
RegisterVariable(extVar.name);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ namespace Nz
|
|||
public:
|
||||
struct UniformVar
|
||||
{
|
||||
std::optional<UInt32> bindingIndex;
|
||||
std::optional<UInt32> descriptorSet;
|
||||
UInt32 bindingIndex;
|
||||
UInt32 descriptorSet;
|
||||
UInt32 pointerId;
|
||||
};
|
||||
|
||||
|
|
@ -123,10 +123,12 @@ namespace Nz
|
|||
variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
|
||||
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
|
||||
|
||||
assert(extVar.bindingIndex);
|
||||
|
||||
UniformVar& uniformVar = extVars[varIndex++];
|
||||
uniformVar.pointerId = m_constantCache.Register(variable);
|
||||
uniformVar.bindingIndex = extVar.bindingIndex;
|
||||
uniformVar.descriptorSet = extVar.bindingSet;
|
||||
uniformVar.bindingIndex = *extVar.bindingIndex;
|
||||
uniformVar.descriptorSet = extVar.bindingSet.value_or(0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -491,11 +493,8 @@ namespace Nz
|
|||
|
||||
for (auto&& [varIndex, extVar] : preVisitor.extVars)
|
||||
{
|
||||
if (extVar.bindingIndex)
|
||||
{
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, *extVar.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, *extVar.descriptorSet);
|
||||
}
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, extVar.bindingIndex);
|
||||
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, extVar.descriptorSet);
|
||||
}
|
||||
|
||||
for (auto&& [varId, builtin] : preVisitor.builtinDecorations)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
#include <catch2/catch.hpp>
|
||||
#include <cctype>
|
||||
|
||||
void ExpectingGLSL(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput)
|
||||
void ExpectingGLSL(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput)
|
||||
{
|
||||
Nz::GlslWriter writer;
|
||||
|
||||
|
|
@ -19,7 +19,7 @@ void ExpectingGLSL(Nz::ShaderAst::StatementPtr& shader, std::string_view expecte
|
|||
REQUIRE(subset == expectedOutput);
|
||||
}
|
||||
|
||||
void ExpectingSpirV(Nz::ShaderAst::StatementPtr& shader, std::string_view expectedOutput)
|
||||
void ExpectingSpirV(Nz::ShaderAst::Statement& shader, std::string_view expectedOutput)
|
||||
{
|
||||
Nz::SpirvWriter writer;
|
||||
auto spirv = writer.Generate(shader);
|
||||
|
|
@ -64,6 +64,7 @@ SCENARIO("Shader generation", "[Shader]")
|
|||
|
||||
auto external = std::make_unique<Nz::ShaderAst::DeclareExternalStatement>();
|
||||
external->externalVars.push_back({
|
||||
0,
|
||||
std::nullopt,
|
||||
"ubo",
|
||||
Nz::ShaderAst::UniformType{ Nz::ShaderAst::IdentifierType{ "outerStruct" } }
|
||||
|
|
@ -85,7 +86,7 @@ SCENARIO("Shader generation", "[Shader]")
|
|||
|
||||
SECTION("Generating GLSL")
|
||||
{
|
||||
ExpectingGLSL(shader, R"(
|
||||
ExpectingGLSL(*shader, R"(
|
||||
void main()
|
||||
{
|
||||
float result = ubo.s.field.z;
|
||||
|
|
@ -94,7 +95,7 @@ void main()
|
|||
}
|
||||
SECTION("Generating Spir-V")
|
||||
{
|
||||
ExpectingSpirV(shader, R"(
|
||||
ExpectingSpirV(*shader, R"(
|
||||
OpFunction
|
||||
OpLabel
|
||||
OpVariable
|
||||
|
|
@ -121,7 +122,7 @@ OpFunctionEnd)");
|
|||
|
||||
SECTION("Generating GLSL")
|
||||
{
|
||||
ExpectingGLSL(shader, R"(
|
||||
ExpectingGLSL(*shader, R"(
|
||||
void main()
|
||||
{
|
||||
float result = ubo.s.field.z;
|
||||
|
|
@ -130,7 +131,7 @@ void main()
|
|||
}
|
||||
SECTION("Generating Spir-V")
|
||||
{
|
||||
ExpectingSpirV(shader, R"(
|
||||
ExpectingSpirV(*shader, R"(
|
||||
OpFunction
|
||||
OpLabel
|
||||
OpVariable
|
||||
|
|
|
|||
Loading…
Reference in New Issue