Shader: Attribute can now have expressions as values and struct fields can be conditionally supported
This commit is contained in:
@@ -57,7 +57,7 @@ namespace Nz
|
||||
using ExtVarContainer = std::unordered_map<std::size_t /*varIndex*/, UniformVar>;
|
||||
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
|
||||
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
|
||||
using StructContainer = std::vector<ShaderAst::StructDescription>;
|
||||
using StructContainer = std::vector<ShaderAst::StructDescription*>;
|
||||
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector<SpirvAstVisitor::FuncData>& funcs) :
|
||||
m_states(conditions),
|
||||
@@ -68,7 +68,7 @@ namespace Nz
|
||||
m_constantCache.SetStructCallback([this](std::size_t structIndex) -> const ShaderAst::StructDescription&
|
||||
{
|
||||
assert(structIndex < declaredStructs.size());
|
||||
return declaredStructs[structIndex];
|
||||
return *declaredStructs[structIndex];
|
||||
});
|
||||
}
|
||||
|
||||
@@ -88,18 +88,12 @@ namespace Nz
|
||||
|
||||
void Visit(ShaderAst::ConditionalExpression& node) override
|
||||
{
|
||||
if (TestBit<Nz::UInt64>(m_states.enabledOptions, node.optionIndex))
|
||||
node.truePath->Visit(*this);
|
||||
else
|
||||
node.falsePath->Visit(*this);
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?");
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::ConditionalStatement& node) override
|
||||
{
|
||||
if (TestBit<Nz::UInt64>(m_states.enabledOptions, node.optionIndex))
|
||||
node.statement->Visit(*this);
|
||||
throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?");
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::ConstantExpression& node) override
|
||||
@@ -123,12 +117,12 @@ namespace Nz
|
||||
variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform;
|
||||
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
|
||||
|
||||
assert(extVar.bindingIndex);
|
||||
assert(extVar.bindingIndex.IsResultingValue());
|
||||
|
||||
UniformVar& uniformVar = extVars[varIndex++];
|
||||
uniformVar.pointerId = m_constantCache.Register(variable);
|
||||
uniformVar.bindingIndex = *extVar.bindingIndex;
|
||||
uniformVar.descriptorSet = extVar.bindingSet.value_or(0);
|
||||
uniformVar.bindingIndex = extVar.bindingIndex.GetResultingValue();
|
||||
uniformVar.descriptorSet = (extVar.bindingSet.HasValue()) ? extVar.bindingSet.GetResultingValue() : 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,7 +145,9 @@ namespace Nz
|
||||
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
std::optional<ShaderStageType> entryPointType = node.entryStage;
|
||||
std::optional<ShaderStageType> entryPointType;
|
||||
if (node.entryStage.HasValue())
|
||||
entryPointType = node.entryStage.GetResultingValue();
|
||||
|
||||
assert(node.funcIndex);
|
||||
std::size_t funcIndex = *node.funcIndex;
|
||||
@@ -188,14 +184,14 @@ namespace Nz
|
||||
if (*entryPointType == ShaderStageType::Fragment)
|
||||
{
|
||||
executionModes.push_back(SpirvExecutionMode::OriginUpperLeft);
|
||||
if (node.earlyFragmentTests && *node.earlyFragmentTests)
|
||||
if (node.earlyFragmentTests.HasValue() && node.earlyFragmentTests.GetResultingValue())
|
||||
executionModes.push_back(SpirvExecutionMode::EarlyFragmentTests);
|
||||
|
||||
if (node.depthWrite)
|
||||
if (node.depthWrite.HasValue())
|
||||
{
|
||||
executionModes.push_back(SpirvExecutionMode::DepthReplacing);
|
||||
|
||||
switch (*node.depthWrite)
|
||||
switch (node.depthWrite.GetResultingValue())
|
||||
{
|
||||
case ShaderAst::DepthWriteMode::Replace: break;
|
||||
case ShaderAst::DepthWriteMode::Greater: executionModes.push_back(SpirvExecutionMode::DepthGreater); break;
|
||||
@@ -217,10 +213,10 @@ namespace Nz
|
||||
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
|
||||
|
||||
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
|
||||
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
|
||||
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
|
||||
|
||||
std::size_t memberIndex = 0;
|
||||
for (const auto& member : structDesc.members)
|
||||
for (const auto& member : structDesc->members)
|
||||
{
|
||||
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0)
|
||||
{
|
||||
@@ -247,10 +243,10 @@ namespace Nz
|
||||
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
|
||||
|
||||
std::size_t structIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
|
||||
const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex];
|
||||
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
|
||||
|
||||
std::size_t memberIndex = 0;
|
||||
for (const auto& member : structDesc.members)
|
||||
for (const auto& member : structDesc->members)
|
||||
{
|
||||
if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0)
|
||||
{
|
||||
@@ -291,7 +287,7 @@ namespace Nz
|
||||
if (structIndex >= declaredStructs.size())
|
||||
declaredStructs.resize(structIndex + 1);
|
||||
|
||||
declaredStructs[structIndex] = node.description;
|
||||
declaredStructs[structIndex] = &node.description;
|
||||
|
||||
m_constantCache.Register(*m_constantCache.BuildType(node.description));
|
||||
}
|
||||
@@ -357,9 +353,9 @@ namespace Nz
|
||||
|
||||
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
|
||||
{
|
||||
if (member.builtin)
|
||||
if (member.builtin.HasValue())
|
||||
{
|
||||
auto it = s_builtinMapping.find(*member.builtin);
|
||||
auto it = s_builtinMapping.find(member.builtin.GetResultingValue());
|
||||
assert(it != s_builtinMapping.end());
|
||||
|
||||
Builtin& builtin = it->second;
|
||||
@@ -379,7 +375,7 @@ namespace Nz
|
||||
|
||||
return varId;
|
||||
}
|
||||
else if (member.locationIndex)
|
||||
else if (member.locationIndex.HasValue())
|
||||
{
|
||||
SpirvConstantCache::Variable variable;
|
||||
variable.debugName = member.name;
|
||||
@@ -388,7 +384,7 @@ namespace Nz
|
||||
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
|
||||
|
||||
UInt32 varId = m_constantCache.Register(variable);
|
||||
locationDecorations[varId] = *member.locationIndex;
|
||||
locationDecorations[varId] = member.locationIndex.GetResultingValue();
|
||||
|
||||
return varId;
|
||||
}
|
||||
@@ -453,7 +449,10 @@ namespace Nz
|
||||
ShaderAst::StatementPtr sanitizedAst;
|
||||
if (!states.sanitized)
|
||||
{
|
||||
sanitizedAst = ShaderAst::Sanitize(shader);
|
||||
ShaderAst::SanitizeVisitor::Options options;
|
||||
options.enabledOptions = states.enabledOptions;
|
||||
|
||||
sanitizedAst = ShaderAst::Sanitize(shader, options);
|
||||
targetAst = sanitizedAst.get();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user