Shader: Attribute can now have expressions as values and struct fields can be conditionally supported

This commit is contained in:
Jérôme Leclercq
2021-07-07 11:41:58 +02:00
parent 749b40cb31
commit f9af35b489
36 changed files with 945 additions and 600 deletions

View File

@@ -57,7 +57,7 @@ namespace Nz
using ExtVarContainer = std::unordered_map<std::size_t /*varIndex*/, UniformVar>;
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
using StructContainer = std::vector<ShaderAst::StructDescription>;
using StructContainer = std::vector<ShaderAst::StructDescription*>;
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector<SpirvAstVisitor::FuncData>& funcs) :
m_states(conditions),
@@ -68,7 +68,7 @@ namespace Nz
m_constantCache.SetStructCallback([this](std::size_t structIndex) -> const ShaderAst::StructDescription&
{
assert(structIndex < declaredStructs.size());
return declaredStructs[structIndex];
return *declaredStructs[structIndex];
});
}
@@ -88,18 +88,12 @@ namespace Nz
void Visit(ShaderAst::ConditionalExpression& node) override
{
if (TestBit<Nz::UInt64>(m_states.enabledOptions, node.optionIndex))
node.truePath->Visit(*this);
else
node.falsePath->Visit(*this);
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?");
}
void Visit(ShaderAst::ConditionalStatement& node) override
{
if (TestBit<Nz::UInt64>(m_states.enabledOptions, node.optionIndex))
node.statement->Visit(*this);
throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?");
}
void Visit(ShaderAst::ConstantExpression& node) override
@@ -123,12 +117,12 @@ namespace Nz
variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
assert(extVar.bindingIndex);
assert(extVar.bindingIndex.IsResultingValue());
UniformVar& uniformVar = extVars[varIndex++];
uniformVar.pointerId = m_constantCache.Register(variable);
uniformVar.bindingIndex = *extVar.bindingIndex;
uniformVar.descriptorSet = extVar.bindingSet.value_or(0);
uniformVar.bindingIndex = extVar.bindingIndex.GetResultingValue();
uniformVar.descriptorSet = (extVar.bindingSet.HasValue()) ? extVar.bindingSet.GetResultingValue() : 0;
}
}
@@ -151,7 +145,9 @@ namespace Nz
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
std::optional<ShaderStageType> entryPointType = node.entryStage;
std::optional<ShaderStageType> entryPointType;
if (node.entryStage.HasValue())
entryPointType = node.entryStage.GetResultingValue();
assert(node.funcIndex);
std::size_t funcIndex = *node.funcIndex;
@@ -188,14 +184,14 @@ namespace Nz
if (*entryPointType == ShaderStageType::Fragment)
{
executionModes.push_back(SpirvExecutionMode::OriginUpperLeft);
if (node.earlyFragmentTests && *node.earlyFragmentTests)
if (node.earlyFragmentTests.HasValue() && node.earlyFragmentTests.GetResultingValue())
executionModes.push_back(SpirvExecutionMode::EarlyFragmentTests);
if (node.depthWrite)
if (node.depthWrite.HasValue())
{
executionModes.push_back(SpirvExecutionMode::DepthReplacing);
switch (*node.depthWrite)
switch (node.depthWrite.GetResultingValue())
{
case ShaderAst::DepthWriteMode::Replace: break;
case ShaderAst::DepthWriteMode::Greater: executionModes.push_back(SpirvExecutionMode::DepthGreater); break;
@@ -217,10 +213,10 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
for (const auto& member : structDesc.members)
for (const auto& member : structDesc->members)
{
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0)
{
@@ -247,10 +243,10 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
for (const auto& member : structDesc.members)
for (const auto& member : structDesc->members)
{
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0)
{
@@ -291,7 +287,7 @@ namespace Nz
if (structIndex >= declaredStructs.size())
declaredStructs.resize(structIndex + 1);
declaredStructs[structIndex] = node.description;
declaredStructs[structIndex] = &node.description;
m_constantCache.Register(*m_constantCache.BuildType(node.description));
}
@@ -357,9 +353,9 @@ namespace Nz
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
{
if (member.builtin)
if (member.builtin.HasValue())
{
auto it = s_builtinMapping.find(*member.builtin);
auto it = s_builtinMapping.find(member.builtin.GetResultingValue());
assert(it != s_builtinMapping.end());
Builtin& builtin = it->second;
@@ -379,7 +375,7 @@ namespace Nz
return varId;
}
else if (member.locationIndex)
else if (member.locationIndex.HasValue())
{
SpirvConstantCache::Variable variable;
variable.debugName = member.name;
@@ -388,7 +384,7 @@ namespace Nz
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
UInt32 varId = m_constantCache.Register(variable);
locationDecorations[varId] = *member.locationIndex;
locationDecorations[varId] = member.locationIndex.GetResultingValue();
return varId;
}
@@ -453,7 +449,10 @@ namespace Nz
ShaderAst::StatementPtr sanitizedAst;
if (!states.sanitized)
{
sanitizedAst = ShaderAst::Sanitize(shader);
ShaderAst::SanitizeVisitor::Options options;
options.enabledOptions = states.enabledOptions;
sanitizedAst = ShaderAst::Sanitize(shader, options);
targetAst = sanitizedAst.get();
}