Shader: Move attribute parsing to parser, simplifying writer code

This commit is contained in:
Jérôme Leclercq
2021-04-14 11:34:21 +02:00
parent bca1561f73
commit aababb205f
14 changed files with 910 additions and 1046 deletions

View File

@@ -24,12 +24,6 @@ namespace Nz
static const char* s_outputPrefix = "_NzOut_";
static const char* s_outputVarName = "_nzOutput";
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
template<typename T> const T& Retrieve(const std::unordered_map<std::size_t, T>& map, std::size_t id)
{
auto it = map.find(id);
@@ -49,29 +43,19 @@ namespace Nz
ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get());
// Remove function if it's an entry point of another type than the one selected
bool isEntryPoint = false;
for (auto& attribute : func->attributes)
if (node.entryStage)
{
if (attribute.type == ShaderAst::AttributeType::Entry)
{
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
ShaderStageType stage = *node.entryStage;
if (stage != selectedStage)
return ShaderBuilder::NoOp();
if (it->second != selectedEntryPoint)
return ShaderBuilder::NoOp();
isEntryPoint = true;
break;
}
}
if (isEntryPoint)
entryPoint = func;
}
return clone;
}
ShaderStageType selectedEntryPoint;
ShaderStageType selectedStage;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
};
@@ -81,8 +65,8 @@ namespace Nz
ShaderStageTypeFlags stageFlags;
};
std::unordered_map<std::string, Builtin> builtinMapping = {
{ "position", { "gl_Position", ShaderStageType::Vertex } }
std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
{ ShaderAst::BuiltinEntry::VertexPosition, { "gl_Position", ShaderStageType::Vertex } }
};
}
@@ -95,17 +79,11 @@ namespace Nz
std::string targetName;
};
struct StructInfo
{
ShaderAst::StructDescription structDesc;
bool isStd140 = false;
};
ShaderStageType stage;
const States* states = nullptr;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream;
std::unordered_map<std::size_t, StructInfo> structs;
std::unordered_map<std::size_t, ShaderAst::StructDescription> structs;
std::unordered_map<std::size_t, std::string> variableNames;
std::vector<InOutField> inputFields;
std::vector<InOutField> outputFields;
@@ -138,7 +116,7 @@ namespace Nz
ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader);
PreVisitor previsitor;
previsitor.selectedEntryPoint = shaderStage;
previsitor.selectedStage = shaderStage;
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(transformedShader);
@@ -241,13 +219,13 @@ namespace Nz
void GlslWriter::Append(const ShaderAst::StructType& structType)
{
const auto& structDesc = Retrieve(m_currentState->structs, structType.structIndex).structDesc;
const auto& structDesc = Retrieve(m_currentState->structs, structType.structIndex);
Append(structDesc.name);
}
void GlslWriter::Append(const ShaderAst::UniformType& uniformType)
{
/* TODO */
throw std::runtime_error("unexpected UniformType");
}
void GlslWriter::Append(const ShaderAst::VectorType& vecType)
@@ -305,7 +283,7 @@ namespace Nz
void GlslWriter::AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers)
{
const auto& structDesc = Retrieve(m_currentState->structs, structIndex).structDesc;
const auto& structDesc = Retrieve(m_currentState->structs, structIndex);
const auto& member = structDesc.members[*memberIndices];
@@ -371,7 +349,7 @@ namespace Nz
assert(IsStructType(parameter.type));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex).structDesc;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex);
AppendLine(structDesc.name, " ", varName, ";");
for (const auto& [memberName, targetName] : m_currentState->inputFields)
@@ -397,41 +375,24 @@ namespace Nz
{
for (const auto& member : structDesc.members)
{
bool skip = false;
std::optional<std::string> builtinName;
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
if (member.builtin)
{
if (attributeType == ShaderAst::AttributeType::Builtin)
{
auto it = builtinMapping.find(std::get<std::string>(attributeParam));
if (it != builtinMapping.end())
{
const Builtin& builtin = it->second;
if (!builtin.stageFlags.Test(m_currentState->stage))
{
skip = true;
break;
}
auto it = s_builtinMapping.find(member.builtin.value());
assert(it != s_builtinMapping.end());
builtinName = builtin.identifier;
break;
}
}
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
const Builtin& builtin = it->second;
if (!builtin.stageFlags.Test(m_currentState->stage))
continue; //< This builtin is not active in this stage, skip it
fields.push_back({
member.name,
builtin.identifier
});
}
if (skip)
continue;
if (attributeLocation)
else if (member.locationIndex)
{
Append("layout(location = ");
Append(*attributeLocation);
Append(*member.locationIndex);
Append(") ");
Append(keyword);
Append(" ");
@@ -446,13 +407,6 @@ namespace Nz
targetPrefix + member.name
});
}
else if (builtinName)
{
fields.push_back({
member.name,
*builtinName
});
}
}
AppendLine();
};
@@ -468,7 +422,7 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
std::size_t inputStructIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
inputStruct = &Retrieve(m_currentState->structs, inputStructIndex).structDesc;
inputStruct = &Retrieve(m_currentState->structs, inputStructIndex);
AppendCommentSection("Inputs");
AppendInOut(*inputStruct, m_currentState->inputFields, "in", s_inputPrefix);
@@ -485,20 +439,17 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
std::size_t outputStructIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
const ShaderAst::StructDescription& outputStruct = Retrieve(m_currentState->structs, outputStructIndex).structDesc;
const ShaderAst::StructDescription& outputStruct = Retrieve(m_currentState->structs, outputStructIndex);
AppendCommentSection("Outputs");
AppendInOut(outputStruct, m_currentState->outputFields, "out", s_outputPrefix);
}
}
void GlslWriter::RegisterStruct(std::size_t structIndex, bool isStd140, ShaderAst::StructDescription desc)
void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc)
{
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
m_currentState->structs.emplace(structIndex, State::StructInfo{
std::move(desc),
isStd140
});
m_currentState->structs.emplace(structIndex, std::move(desc));
}
void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName)
@@ -667,16 +618,6 @@ namespace Nz
for (const auto& externalVar : node.externalVars)
{
std::optional<long long> bindingIndex;
for (const auto& [attributeType, attributeParam] : externalVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
{
bindingIndex = std::get<long long>(attributeParam);
break;
}
}
bool isStd140 = false;
if (IsUniformType(externalVar.type))
{
@@ -685,13 +626,13 @@ namespace Nz
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structInfo = Retrieve(m_currentState->structs, structIndex);
isStd140 = structInfo.isStd140;
isStd140 = structInfo.layout == StructLayout_Std140;
}
if (bindingIndex)
if (externalVar.bindingIndex)
{
Append("layout(binding = ");
Append(*bindingIndex);
Append(*externalVar.bindingIndex);
if (isStd140)
Append(", std140");
@@ -708,19 +649,19 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType));
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structInfo = Retrieve(m_currentState->structs, structIndex);
auto& structDesc = Retrieve(m_currentState->structs, structIndex);
bool first = true;
for (const auto& [name, attribute, type] : structInfo.structDesc.members)
for (const auto& member : structDesc.members)
{
if (!first)
AppendLine();
first = false;
Append(type);
Append(member.type);
Append(" ");
Append(name);
Append(member.name);
Append(";");
}
}
@@ -778,34 +719,24 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
bool isStd140 = false;
for (const auto& [attributeType, attributeParam] : node.attributes)
{
if (attributeType == ShaderAst::AttributeType::Layout && std::get<std::string>(attributeParam) == "std140")
{
isStd140 = true;
break;
}
}
assert(node.structIndex);
RegisterStruct(*node.structIndex, isStd140, node.description);
RegisterStruct(*node.structIndex, node.description);
Append("struct ");
AppendLine(node.description.name);
EnterScope();
{
bool first = true;
for (const auto& [name, attribute, type] : node.description.members)
for (const auto& member : node.description.members)
{
if (!first)
AppendLine();
first = false;
Append(type);
Append(member.type);
Append(" ");
Append(name);
Append(member.name);
Append(";");
}
}
@@ -897,7 +828,7 @@ namespace Nz
const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr);
assert(IsStructType(returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex).structDesc;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex);
std::string outputStructVarName;
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression)

View File

@@ -45,7 +45,6 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(DeclareExternalStatement& node)
{
auto clone = std::make_unique<DeclareExternalStatement>();
clone->attributes = node.attributes;
clone->externalVars = node.externalVars;
clone->varIndex = node.varIndex;
@@ -55,7 +54,7 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(DeclareFunctionStatement& node)
{
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->attributes = node.attributes;
clone->entryStage = node.entryStage;
clone->funcIndex = node.funcIndex;
clone->name = node.name;
clone->parameters = node.parameters;

View File

@@ -115,6 +115,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareVariableStatement& node)
{
OptVal(node.varIndex);
Value(node.varName);
Type(node.varType);
Node(node.initialExpression);
@@ -168,14 +169,14 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareExternalStatement& node)
{
Attributes(node.attributes);
OptVal(node.varIndex);
Container(node.externalVars);
for (auto& extVar : node.externalVars)
{
Attributes(extVar.attributes);
Value(extVar.name);
Type(extVar.type);
OptVal(extVar.bindingIndex);
}
}
@@ -183,8 +184,9 @@ namespace Nz::ShaderAst
{
Value(node.name);
Type(node.returnType);
Attributes(node.attributes);
OptEnum(node.entryStage);
OptVal(node.funcIndex);
OptVal(node.varIndex);
Container(node.parameters);
for (auto& parameter : node.parameters)
@@ -200,13 +202,18 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareStructStatement& node)
{
OptVal(node.structIndex);
Value(node.description.name);
OptEnum(node.description.layout);
Container(node.description.members);
for (auto& member : node.description.members)
{
Value(member.name);
Type(member.type);
OptEnum(member.builtin);
OptVal(member.locationIndex);
}
}
@@ -246,78 +253,6 @@ namespace Nz::ShaderAst
m_stream.FlushBits();
}
void AstSerializerBase::Attributes(std::vector<Attribute>& attributes)
{
Container(attributes);
for (auto& attribute : attributes)
{
Enum(attribute.type);
if (IsWriting())
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::monostate>)
{
UInt8 typeId = 0;
Value(typeId);
}
else if constexpr (std::is_same_v<T, long long>)
{
UInt8 typeId = 1;
UInt64 v = UInt64(arg);
Value(typeId);
Value(v);
}
else if constexpr (std::is_same_v<T, std::string>)
{
UInt8 typeId = 2;
Value(typeId);
Value(arg);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, attribute.args);
}
else
{
UInt8 typeId;
Value(typeId);
switch (typeId)
{
case 0:
attribute.args.emplace<std::monostate>();
break;
case 1:
{
UInt64 arg;
Value(arg);
attribute.args = static_cast<long long>(arg);
break;
}
case 2:
{
std::string arg;
Value(arg);
attribute.args = std::move(arg);
break;
}
default:
throw std::runtime_error("invalid attribute type id");
}
}
}
}
bool ShaderAstSerializer::IsWriting() const
{
return true;

View File

@@ -11,14 +11,6 @@
namespace Nz::ShaderAst
{
namespace
{
std::unordered_map<std::string, ShaderStageType> entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
}
struct AstError
{
std::string errMsg;
@@ -28,7 +20,7 @@ namespace Nz::ShaderAst
{
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<long long> usedBindingIndexes;
std::unordered_set<unsigned int> usedBindingIndexes;
};
bool AstValidator::Validate(StatementPtr& node, std::string* error)
@@ -552,44 +544,15 @@ namespace Nz::ShaderAst
void AstValidator::Visit(DeclareExternalStatement& node)
{
if (!node.attributes.empty())
throw AstError{ "unhandled attribute for external block" };
/*for (const auto& [attributeType, arg] : node.attributes)
{
switch (attributeType)
{
default:
throw AstError{ "unhandled attribute for external block" };
}
}*/
for (const auto& extVar : node.externalVars)
{
bool hasBinding = false;
for (const auto& [attributeType, arg] : extVar.attributes)
if (extVar.bindingIndex)
{
switch (attributeType)
{
case AttributeType::Binding:
{
if (hasBinding)
throw AstError{ "attribute binding must be present once" };
unsigned int bindingIndex = extVar.bindingIndex.value();
if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" };
if (!std::holds_alternative<long long>(arg))
throw AstError{ "attribute binding requires a string parameter" };
long long bindingIndex = std::get<long long>(arg);
if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" };
m_context->usedBindingIndexes.insert(bindingIndex);
break;
}
default:
throw AstError{ "unhandled attribute for external variable" };
}
m_context->usedBindingIndexes.insert(bindingIndex);
}
if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end())
@@ -603,42 +566,17 @@ namespace Nz::ShaderAst
void AstValidator::Visit(DeclareFunctionStatement& node)
{
bool hasEntry = false;
for (const auto& [attributeType, arg] : node.attributes)
if (node.entryStage)
{
switch (attributeType)
{
case AttributeType::Entry:
{
if (hasEntry)
throw AstError{ "attribute entry must be present once" };
ShaderStageType stageType = *node.entryStage;
if (!std::holds_alternative<std::string>(arg))
throw AstError{ "attribute entry requires a string parameter" };
if (m_context->entryFunctions[UnderlyingCast(stageType)])
throw AstError{ "the same entry type has been defined multiple times" };
const std::string& argStr = std::get<std::string>(arg);
m_context->entryFunctions[UnderlyingCast(stageType)] = &node;
auto it = entryPoints.find(argStr);
if (it == entryPoints.end())
throw AstError{ "invalid parameter " + argStr + " for entry attribute" };
ShaderStageType stageType = it->second;
if (m_context->entryFunctions[UnderlyingCast(stageType)])
throw AstError{ "the same entry type has been defined multiple times" };
m_context->entryFunctions[UnderlyingCast(it->second)] = &node;
if (node.parameters.size() > 1)
throw AstError{ "entry functions can either take one struct parameter or no parameter" };
hasEntry = true;
break;
}
default:
throw AstError{ "unhandled attribute for function" };
}
if (node.parameters.size() > 1)
throw AstError{ "entry functions can either take one struct parameter or no parameter" };
}
for (auto& statement : node.statements)

File diff suppressed because it is too large Load Diff

View File

@@ -27,12 +27,6 @@ namespace Nz
{
namespace
{
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
struct Builtin
{
const char* debugName;
@@ -40,8 +34,8 @@ namespace Nz
SpirvBuiltIn decoration;
};
std::unordered_map<std::string, Builtin> s_builtinMapping = {
{ "position", { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
{ ShaderAst::BuiltinEntry::VertexPosition, { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
};
class PreVisitor : public ShaderAst::AstScopedVisitor
@@ -129,32 +123,13 @@ namespace Nz
UniformVar& uniformVar = extVars[varIndex++];
uniformVar.pointerId = m_constantCache.Register(variable);
for (const auto& [attributeType, attributeParam] : extVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
{
uniformVar.bindingIndex = std::get<long long>(attributeParam);
break;
}
}
uniformVar.bindingIndex = extVar.bindingIndex;
}
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
std::optional<ShaderStageType> entryPointType;
for (auto& attribute : node.attributes)
{
if (attribute.type == ShaderAst::AttributeType::Entry)
{
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
entryPointType = it->second;
break;
}
}
std::optional<ShaderStageType> entryPointType = node.entryStage;
assert(node.funcIndex);
std::size_t funcIndex = *node.funcIndex;
@@ -325,29 +300,12 @@ namespace Nz
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
{
std::optional<std::reference_wrapper<Builtin>> builtinOpt;
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
if (member.builtin)
{
if (attributeType == ShaderAst::AttributeType::Builtin)
{
auto it = s_builtinMapping.find(std::get<std::string>(attributeParam));
if (it != s_builtinMapping.end())
{
builtinOpt = it->second;
break;
}
}
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
}
auto it = s_builtinMapping.find(*member.builtin);
assert(it != s_builtinMapping.end());
if (builtinOpt)
{
Builtin& builtin = *builtinOpt;
Builtin& builtin = it->second;
if ((builtin.compatibleStages & entryPointType) == 0)
return 0;
@@ -364,7 +322,7 @@ namespace Nz
return varId;
}
else if (attributeLocation)
else if (member.locationIndex)
{
SpirvConstantCache::Variable variable;
variable.debugName = member.name;
@@ -373,7 +331,7 @@ namespace Nz
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
UInt32 varId = m_constantCache.Register(variable);
locationDecorations[varId] = *attributeLocation;
locationDecorations[varId] = *member.locationIndex;
return varId;
}