Add initial support for shader binding sets (WIP)

This commit is contained in:
Jérôme Leclercq
2021-06-14 22:35:05 +02:00
parent 815a7b0c62
commit f22b501e25
53 changed files with 885 additions and 511 deletions

View File

@@ -221,6 +221,7 @@ namespace Nz::ShaderAst
Value(extVar.name);
Type(extVar.type);
OptVal(extVar.bindingIndex);
OptVal(extVar.bindingSet);
}
}

View File

@@ -41,7 +41,7 @@ namespace Nz::ShaderAst
Options options;
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<unsigned int> usedBindingIndexes;
std::unordered_set<UInt64> usedBindingIndexes;
FunctionData* currentFunction = nullptr;
};
@@ -704,14 +704,19 @@ namespace Nz::ShaderAst
for (const auto& extVar : node.externalVars)
{
if (extVar.bindingIndex)
{
unsigned int bindingIndex = extVar.bindingIndex.value();
if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" };
if (!extVar.bindingIndex)
throw AstError{ "external variable " + extVar.name + " requires a binding index" };
m_context->usedBindingIndexes.insert(bindingIndex);
}
if (!extVar.bindingSet)
throw AstError{ "external variable " + extVar.name + " requires a binding set" };
UInt64 bindingIndex = *extVar.bindingIndex;
UInt64 bindingSet = *extVar.bindingSet;
UInt64 bindingKey = bindingSet << 32 | bindingIndex;
if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" };
m_context->usedBindingIndexes.insert(bindingKey);
if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end())
throw AstError{ "External variable " + extVar.name + " is already declared" };

View File

@@ -14,6 +14,7 @@
#include <Nazara/Shader/Ast/AstUtils.hpp>
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
#include <optional>
#include <set>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
@@ -51,6 +52,22 @@ namespace Nz
node.statement->Visit(*this);
}
void Visit(ShaderAst::DeclareExternalStatement& node) override
{
AstRecursiveVisitor::Visit(node);
for (auto& extVar : node.externalVars)
{
assert(extVar.bindingIndex);
assert(extVar.bindingSet);
UInt64 set = *extVar.bindingSet;
UInt64 binding = *extVar.bindingIndex;
bindings.insert(set << 32 | binding);
}
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
// Dismiss function if it's an entry point of another type than the one selected
@@ -96,6 +113,7 @@ namespace Nz
FunctionData* currentFunction = nullptr;
std::set<UInt64 /*set | binding*/> bindings;
std::optional<ShaderStageType> selectedStage;
std::unordered_map<std::size_t, FunctionData> functions;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
@@ -118,6 +136,11 @@ namespace Nz
struct GlslWriter::State
{
State(const GlslWriter::BindingMapping& bindings) :
bindingMapping(bindings)
{
}
struct InOutField
{
std::string memberName;
@@ -131,6 +154,7 @@ namespace Nz
std::vector<InOutField> inputFields;
std::vector<InOutField> outputFields;
Bitset<> declaredFunctions;
const GlslWriter::BindingMapping& bindingMapping;
PreVisitor previsitor;
const States* states = nullptr;
UInt64 enabledOptions = 0;
@@ -138,9 +162,9 @@ namespace Nz
unsigned int indentLevel = 0;
};
std::string GlslWriter::Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::Statement& shader, const States& states)
std::string GlslWriter::Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::Statement& shader, const BindingMapping& bindingMapping, const States& states)
{
State state;
State state(bindingMapping);
state.enabledOptions = states.enabledOptions;
state.stage = shaderStage;
@@ -254,7 +278,7 @@ namespace Nz
{
case ShaderAst::PrimitiveType::Boolean: return Append("bool");
case ShaderAst::PrimitiveType::Float32: return Append("float");
case ShaderAst::PrimitiveType::Int32: return Append("ivec2");
case ShaderAst::PrimitiveType::Int32: return Append("int");
case ShaderAst::PrimitiveType::UInt32: return Append("uint");
}
}
@@ -521,7 +545,7 @@ namespace Nz
{
if (node.entryStage == ShaderStageType::Fragment && node.earlyFragmentTests && *node.earlyFragmentTests)
{
if ((m_environment.glES && m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 1) || (!m_environment.glES && m_environment.glMajorVersion >= 4 && m_environment.glMinorVersion >= 2) || m_environment.extCallback("GL_ARB_shader_image_load_store"))
if ((m_environment.glES && m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 1) || (!m_environment.glES && m_environment.glMajorVersion >= 4 && m_environment.glMinorVersion >= 2) || (m_environment.extCallback && m_environment.extCallback("GL_ARB_shader_image_load_store")))
{
AppendLine("layout(early_fragment_tests) in;");
AppendLine();
@@ -843,54 +867,60 @@ namespace Nz
isStd140 = structInfo.layout == StructLayout::Std140;
}
if (externalVar.bindingIndex)
assert(externalVar.bindingIndex);
assert(externalVar.bindingSet);
UInt64 bindingIndex = *externalVar.bindingIndex;
UInt64 bindingSet = *externalVar.bindingSet;
auto bindingIt = m_currentState->bindingMapping.find(bindingSet << 32 | bindingIndex);
if (bindingIt == m_currentState->bindingMapping.end())
throw std::runtime_error("no binding found for (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ")");
Append("layout(binding = ", bindingIt->second);
if (isStd140)
Append(", std140");
Append(") uniform ");
if (IsUniformType(externalVar.type))
{
Append("layout(binding = ");
Append(*externalVar.bindingIndex);
if (isStd140)
Append(", std140");
Append("_NzBinding_");
AppendLine(externalVar.name);
Append(") uniform ");
if (IsUniformType(externalVar.type))
EnterScope();
{
Append("_NzBinding_");
AppendLine(externalVar.name);
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type);
assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType));
EnterScope();
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structDesc = Retrieve(m_currentState->structs, structIndex);
bool first = true;
for (const auto& member : structDesc.members)
{
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type);
assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType));
if (!first)
AppendLine();
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structDesc = Retrieve(m_currentState->structs, structIndex);
first = false;
bool first = true;
for (const auto& member : structDesc.members)
{
if (!first)
AppendLine();
first = false;
Append(member.type);
Append(" ");
Append(member.name);
Append(";");
}
Append(member.type);
Append(" ");
Append(member.name);
Append(";");
}
LeaveScope(false);
}
else
Append(externalVar.type);
Append(" ");
Append(externalVar.name);
AppendLine(";");
if (IsUniformType(externalVar.type))
AppendLine();
LeaveScope(false);
}
else
Append(externalVar.type);
Append(" ");
Append(externalVar.name);
AppendLine(";");
if (IsUniformType(externalVar.type))
AppendLine();
RegisterVariable(varIndex++, externalVar.name);
}

View File

@@ -29,7 +29,7 @@ namespace Nz
struct LangWriter::BindingAttribute
{
std::optional<unsigned int> bindingIndex;
std::optional<UInt32> bindingIndex;
inline bool HasValue() const { return bindingIndex.has_value(); }
};
@@ -71,11 +71,18 @@ namespace Nz
struct LangWriter::LocationAttribute
{
std::optional<unsigned int> locationIndex;
std::optional<UInt32> locationIndex;
inline bool HasValue() const { return locationIndex.has_value(); }
};
struct LangWriter::SetAttribute
{
std::optional<UInt32> setIndex;
inline bool HasValue() const { return setIndex.has_value(); }
};
struct LangWriter::State
{
const States* states = nullptr;
@@ -87,7 +94,7 @@ namespace Nz
unsigned int indentLevel = 0;
};
std::string LangWriter::Generate(ShaderAst::StatementPtr& shader, const States& states)
std::string LangWriter::Generate(ShaderAst::Statement& shader, const States& states)
{
State state;
m_currentState = &state;
@@ -220,8 +227,10 @@ namespace Nz
if (!hasAnyAttribute)
return;
bool first = true;
Append("[");
(AppendAttribute(params), ...);
AppendAttributesInternal(first, std::forward<Args>(params)...);
Append("]");
if (appendLine)
@@ -230,6 +239,27 @@ namespace Nz
Append(" ");
}
template<typename T>
void LangWriter::AppendAttributesInternal(bool& first, const T& param)
{
if (!param.HasValue())
return;
if (!first)
Append(", ");
first = false;
AppendAttribute(param);
}
template<typename T1, typename T2, typename... Rest>
void LangWriter::AppendAttributesInternal(bool& first, const T1& firstParam, const T2& secondParam, Rest&&... params)
{
AppendAttributesInternal(first, firstParam);
AppendAttributesInternal(first, secondParam, std::forward<Rest>(params)...);
}
void LangWriter::AppendAttribute(BindingAttribute binding)
{
if (!binding.HasValue())
@@ -333,6 +363,14 @@ namespace Nz
Append("location(", *location.locationIndex, ")");
}
void LangWriter::AppendAttribute(SetAttribute set)
{
if (!set.HasValue())
return;
Append("set(", *set.setIndex, ")");
}
void LangWriter::AppendCommentSection(const std::string& section)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@@ -607,7 +645,7 @@ namespace Nz
first = false;
AppendAttributes(false, BindingAttribute{ externalVar.bindingIndex });
AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ externalVar.bindingIndex });
Append(externalVar.name, ": ", externalVar.type);
RegisterVariable(varIndex++, externalVar.name);

View File

@@ -35,6 +35,7 @@ namespace Nz::ShaderLang
{ "layout", ShaderAst::AttributeType::Layout },
{ "location", ShaderAst::AttributeType::Location },
{ "opt", ShaderAst::AttributeType::Option },
{ "set", ShaderAst::AttributeType::Set },
};
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
@@ -379,8 +380,31 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes)
{
if (!attributes.empty())
throw AttributeError{ "unhandled attribute for external block" };
std::optional<UInt32> blockSetIndex;
for (const auto& [attributeType, arg] : attributes)
{
switch (attributeType)
{
case ShaderAst::AttributeType::Set:
{
if (blockSetIndex)
throw AttributeError{ "attribute set must be present once" };
if (!std::holds_alternative<long long>(arg))
throw AttributeError{ "attribute set requires a string parameter" };
std::optional<UInt32> bindingIndex = BoundCast<UInt32>(std::get<long long>(arg));
if (!bindingIndex)
throw AttributeError{ "invalid set index" };
blockSetIndex = bindingIndex.value();
break;
}
default:
throw AttributeError{ "unhandled attribute for external block" };
}
}
Expect(Advance(), TokenType::External);
Expect(Advance(), TokenType::OpenCurlyBracket);
@@ -424,7 +448,7 @@ namespace Nz::ShaderLang
if (!std::holds_alternative<long long>(arg))
throw AttributeError{ "attribute binding requires a string parameter" };
std::optional<unsigned int> bindingIndex = BoundCast<unsigned int>(std::get<long long>(arg));
std::optional<UInt32> bindingIndex = BoundCast<UInt32>(std::get<long long>(arg));
if (!bindingIndex)
throw AttributeError{ "invalid binding index" };
@@ -432,6 +456,22 @@ namespace Nz::ShaderLang
break;
}
case ShaderAst::AttributeType::Set:
{
if (extVar.bindingSet)
throw AttributeError{ "attribute set must be present once" };
if (!std::holds_alternative<long long>(arg))
throw AttributeError{ "attribute set requires a string parameter" };
std::optional<UInt32> bindingIndex = BoundCast<UInt32>(std::get<long long>(arg));
if (!bindingIndex)
throw AttributeError{ "invalid set index" };
extVar.bindingSet = bindingIndex.value();
break;
}
default:
throw AttributeError{ "unhandled attribute for external variable" };
}
@@ -442,6 +482,9 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Colon);
extVar.type = ParseType();
if (!extVar.bindingSet)
extVar.bindingSet = blockSetIndex.value_or(0);
RegisterVariable(extVar.name);
}
@@ -704,7 +747,7 @@ namespace Nz::ShaderLang
if (structField.locationIndex)
throw AttributeError{ "attribute location must be present once" };
structField.locationIndex = BoundCast<unsigned int>(std::get<long long>(attributeParam));
structField.locationIndex = BoundCast<UInt32>(std::get<long long>(attributeParam));
if (!structField.locationIndex)
throw AttributeError{ "invalid location index" };

View File

@@ -47,6 +47,7 @@ namespace Nz
struct UniformVar
{
std::optional<UInt32> bindingIndex;
std::optional<UInt32> descriptorSet;
UInt32 pointerId;
};
@@ -125,6 +126,7 @@ namespace Nz
UniformVar& uniformVar = extVars[varIndex++];
uniformVar.pointerId = m_constantCache.Register(variable);
uniformVar.bindingIndex = extVar.bindingIndex;
uniformVar.descriptorSet = extVar.bindingSet;
}
}
@@ -492,7 +494,7 @@ namespace Nz
if (extVar.bindingIndex)
{
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::Binding, *extVar.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, 0);
state.annotations.Append(SpirvOp::OpDecorate, extVar.pointerId, SpirvDecoration::DescriptorSet, *extVar.descriptorSet);
}
}