Shader: Move attribute parsing to parser, simplifying writer code

This commit is contained in:
Jérôme Leclercq 2021-04-14 11:34:21 +02:00
parent bca1561f73
commit aababb205f
14 changed files with 910 additions and 1046 deletions

View File

@ -11,6 +11,7 @@
#include <Nazara/Utility/Enums.hpp> #include <Nazara/Utility/Enums.hpp>
#include <Nazara/Shader/ShaderEnums.hpp> #include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/Ast/Attribute.hpp> #include <Nazara/Shader/Ast/Attribute.hpp>
#include <optional>
#include <string> #include <string>
#include <variant> #include <variant>
#include <vector> #include <vector>
@ -81,11 +82,13 @@ namespace Nz::ShaderAst
{ {
struct StructMember struct StructMember
{ {
std::optional<BuiltinEntry> builtin;
std::optional<unsigned int> locationIndex;
std::string name; std::string name;
std::vector<Attribute> attributes;
ExpressionType type; ExpressionType type;
}; };
std::optional<StructLayout> layout;
std::string name; std::string name;
std::vector<StructMember> members; std::vector<StructMember> members;
}; };

View File

@ -71,7 +71,7 @@ namespace Nz
void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node); void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node);
void HandleInOut(); void HandleInOut();
void RegisterStruct(std::size_t structIndex, bool isStd140, ShaderAst::StructDescription desc); void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc);
void RegisterVariable(std::size_t varIndex, std::string varName); void RegisterVariable(std::size_t varIndex, std::string varName);
void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false);

View File

@ -48,7 +48,6 @@ namespace Nz::ShaderAst
void Serialize(ReturnStatement& node); void Serialize(ReturnStatement& node);
protected: protected:
void Attributes(std::vector<Attribute>& attributes);
template<typename T> void Container(T& container); template<typename T> void Container(T& container);
template<typename T> void Enum(T& enumVal); template<typename T> void Enum(T& enumVal);
template<typename T> void OptEnum(std::optional<T>& optVal); template<typename T> void OptEnum(std::optional<T>& optVal);

View File

@ -10,6 +10,7 @@
#include <Nazara/Prerequisites.hpp> #include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderNodes.hpp> #include <Nazara/Shader/ShaderNodes.hpp>
#include <memory> #include <memory>
#include <optional>
namespace Nz::ShaderBuilder namespace Nz::ShaderBuilder
{ {
@ -59,13 +60,12 @@ namespace Nz::ShaderBuilder
struct DeclareFunction struct DeclareFunction
{ {
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const; inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const;
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::vector<ShaderAst::Attribute> attributes, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const; inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::optional<ShaderStageType> entryStage, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const;
}; };
struct DeclareStruct struct DeclareStruct
{ {
inline std::unique_ptr<ShaderAst::DeclareStructStatement> operator()(ShaderAst::StructDescription description) const; inline std::unique_ptr<ShaderAst::DeclareStructStatement> operator()(ShaderAst::StructDescription description) const;
inline std::unique_ptr<ShaderAst::DeclareStructStatement> operator()(std::vector<ShaderAst::Attribute> attributes, ShaderAst::StructDescription description) const;
}; };
struct DeclareVariable struct DeclareVariable

View File

@ -108,10 +108,10 @@ namespace Nz::ShaderBuilder
return declareFunctionNode; return declareFunctionNode;
} }
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::vector<ShaderAst::Attribute> attributes, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType) const inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::optional<ShaderStageType> entryStage, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType) const
{ {
auto declareFunctionNode = std::make_unique<ShaderAst::DeclareFunctionStatement>(); auto declareFunctionNode = std::make_unique<ShaderAst::DeclareFunctionStatement>();
declareFunctionNode->attributes = std::move(attributes); declareFunctionNode->entryStage = entryStage;
declareFunctionNode->name = std::move(name); declareFunctionNode->name = std::move(name);
declareFunctionNode->parameters = std::move(parameters); declareFunctionNode->parameters = std::move(parameters);
declareFunctionNode->returnType = std::move(returnType); declareFunctionNode->returnType = std::move(returnType);
@ -128,15 +128,6 @@ namespace Nz::ShaderBuilder
return declareStructNode; return declareStructNode;
} }
inline std::unique_ptr<ShaderAst::DeclareStructStatement> Impl::DeclareStruct::operator()(std::vector<ShaderAst::Attribute> attributes, ShaderAst::StructDescription description) const
{
auto declareStructNode = std::make_unique<ShaderAst::DeclareStructStatement>();
declareStructNode->attributes = std::move(attributes);
declareStructNode->description = std::move(description);
return declareStructNode;
}
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const
{ {
auto declareVariableNode = std::make_unique<ShaderAst::DeclareVariableStatement>(); auto declareVariableNode = std::make_unique<ShaderAst::DeclareVariableStatement>();

View File

@ -25,14 +25,6 @@ namespace Nz::ShaderAst
Location //< Location (struct member only) - has argument index Location //< Location (struct member only) - has argument index
}; };
enum class PrimitiveType
{
Boolean, //< bool
Float32, //< f32
Int32, //< i32
UInt32, //< ui32
};
enum class BinaryType enum class BinaryType
{ {
Add, //< + Add, //< +
@ -80,6 +72,14 @@ namespace Nz::ShaderAst
#include <Nazara/Shader/ShaderAstNodes.hpp> #include <Nazara/Shader/ShaderAstNodes.hpp>
}; };
enum class PrimitiveType
{
Boolean, //< bool
Float32, //< f32
Int32, //< i32
UInt32, //< ui32
};
enum class SwizzleComponent enum class SwizzleComponent
{ {
First, First,

View File

@ -14,6 +14,12 @@
namespace Nz::ShaderLang namespace Nz::ShaderLang
{ {
class AttributeError : public std::exception
{
public:
using exception::exception;
};
class ExpectedToken : public std::exception class ExpectedToken : public std::exception
{ {
public: public:

View File

@ -210,13 +210,12 @@ namespace Nz::ShaderAst
struct ExternalVar struct ExternalVar
{ {
std::optional<unsigned int> bindingIndex;
std::string name; std::string name;
std::vector<Attribute> attributes;
ExpressionType type; ExpressionType type;
}; };
std::optional<std::size_t> varIndex; std::optional<std::size_t> varIndex;
std::vector<Attribute> attributes;
std::vector<ExternalVar> externalVars; std::vector<ExternalVar> externalVars;
}; };
@ -231,10 +230,10 @@ namespace Nz::ShaderAst
ExpressionType type; ExpressionType type;
}; };
std::optional<ShaderStageType> entryStage;
std::optional<std::size_t> funcIndex; std::optional<std::size_t> funcIndex;
std::optional<std::size_t> varIndex; std::optional<std::size_t> varIndex;
std::string name; std::string name;
std::vector<Attribute> attributes;
std::vector<Parameter> parameters; std::vector<Parameter> parameters;
std::vector<StatementPtr> statements; std::vector<StatementPtr> statements;
ExpressionType returnType; ExpressionType returnType;
@ -246,7 +245,6 @@ namespace Nz::ShaderAst
void Visit(AstStatementVisitor& visitor) override; void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> structIndex; std::optional<std::size_t> structIndex;
std::vector<Attribute> attributes;
StructDescription description; StructDescription description;
}; };

View File

@ -24,12 +24,6 @@ namespace Nz
static const char* s_outputPrefix = "_NzOut_"; static const char* s_outputPrefix = "_NzOut_";
static const char* s_outputVarName = "_nzOutput"; static const char* s_outputVarName = "_nzOutput";
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
template<typename T> const T& Retrieve(const std::unordered_map<std::size_t, T>& map, std::size_t id) template<typename T> const T& Retrieve(const std::unordered_map<std::size_t, T>& map, std::size_t id)
{ {
auto it = map.find(id); auto it = map.find(id);
@ -49,29 +43,19 @@ namespace Nz
ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get()); ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get());
// Remove function if it's an entry point of another type than the one selected // Remove function if it's an entry point of another type than the one selected
bool isEntryPoint = false; if (node.entryStage)
for (auto& attribute : func->attributes)
{ {
if (attribute.type == ShaderAst::AttributeType::Entry) ShaderStageType stage = *node.entryStage;
{ if (stage != selectedStage)
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
if (it->second != selectedEntryPoint)
return ShaderBuilder::NoOp(); return ShaderBuilder::NoOp();
isEntryPoint = true;
break;
}
}
if (isEntryPoint)
entryPoint = func; entryPoint = func;
}
return clone; return clone;
} }
ShaderStageType selectedEntryPoint; ShaderStageType selectedStage;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr; ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
}; };
@ -81,8 +65,8 @@ namespace Nz
ShaderStageTypeFlags stageFlags; ShaderStageTypeFlags stageFlags;
}; };
std::unordered_map<std::string, Builtin> builtinMapping = { std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
{ "position", { "gl_Position", ShaderStageType::Vertex } } { ShaderAst::BuiltinEntry::VertexPosition, { "gl_Position", ShaderStageType::Vertex } }
}; };
} }
@ -95,17 +79,11 @@ namespace Nz
std::string targetName; std::string targetName;
}; };
struct StructInfo
{
ShaderAst::StructDescription structDesc;
bool isStd140 = false;
};
ShaderStageType stage; ShaderStageType stage;
const States* states = nullptr; const States* states = nullptr;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr; ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream; std::stringstream stream;
std::unordered_map<std::size_t, StructInfo> structs; std::unordered_map<std::size_t, ShaderAst::StructDescription> structs;
std::unordered_map<std::size_t, std::string> variableNames; std::unordered_map<std::size_t, std::string> variableNames;
std::vector<InOutField> inputFields; std::vector<InOutField> inputFields;
std::vector<InOutField> outputFields; std::vector<InOutField> outputFields;
@ -138,7 +116,7 @@ namespace Nz
ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader); ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader);
PreVisitor previsitor; PreVisitor previsitor;
previsitor.selectedEntryPoint = shaderStage; previsitor.selectedStage = shaderStage;
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(transformedShader); ShaderAst::StatementPtr adaptedShader = previsitor.Clone(transformedShader);
@ -241,13 +219,13 @@ namespace Nz
void GlslWriter::Append(const ShaderAst::StructType& structType) void GlslWriter::Append(const ShaderAst::StructType& structType)
{ {
const auto& structDesc = Retrieve(m_currentState->structs, structType.structIndex).structDesc; const auto& structDesc = Retrieve(m_currentState->structs, structType.structIndex);
Append(structDesc.name); Append(structDesc.name);
} }
void GlslWriter::Append(const ShaderAst::UniformType& uniformType) void GlslWriter::Append(const ShaderAst::UniformType& uniformType)
{ {
/* TODO */ throw std::runtime_error("unexpected UniformType");
} }
void GlslWriter::Append(const ShaderAst::VectorType& vecType) void GlslWriter::Append(const ShaderAst::VectorType& vecType)
@ -305,7 +283,7 @@ namespace Nz
void GlslWriter::AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers) void GlslWriter::AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers)
{ {
const auto& structDesc = Retrieve(m_currentState->structs, structIndex).structDesc; const auto& structDesc = Retrieve(m_currentState->structs, structIndex);
const auto& member = structDesc.members[*memberIndices]; const auto& member = structDesc.members[*memberIndices];
@ -371,7 +349,7 @@ namespace Nz
assert(IsStructType(parameter.type)); assert(IsStructType(parameter.type));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex; std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex).structDesc; const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex);
AppendLine(structDesc.name, " ", varName, ";"); AppendLine(structDesc.name, " ", varName, ";");
for (const auto& [memberName, targetName] : m_currentState->inputFields) for (const auto& [memberName, targetName] : m_currentState->inputFields)
@ -397,41 +375,24 @@ namespace Nz
{ {
for (const auto& member : structDesc.members) for (const auto& member : structDesc.members)
{ {
bool skip = false; if (member.builtin)
std::optional<std::string> builtinName;
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
{
if (attributeType == ShaderAst::AttributeType::Builtin)
{
auto it = builtinMapping.find(std::get<std::string>(attributeParam));
if (it != builtinMapping.end())
{ {
auto it = s_builtinMapping.find(member.builtin.value());
assert(it != s_builtinMapping.end());
const Builtin& builtin = it->second; const Builtin& builtin = it->second;
if (!builtin.stageFlags.Test(m_currentState->stage)) if (!builtin.stageFlags.Test(m_currentState->stage))
{ continue; //< This builtin is not active in this stage, skip it
skip = true;
break;
}
builtinName = builtin.identifier; fields.push_back({
break; member.name,
builtin.identifier
});
} }
} else if (member.locationIndex)
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
}
if (skip)
continue;
if (attributeLocation)
{ {
Append("layout(location = "); Append("layout(location = ");
Append(*attributeLocation); Append(*member.locationIndex);
Append(") "); Append(") ");
Append(keyword); Append(keyword);
Append(" "); Append(" ");
@ -446,13 +407,6 @@ namespace Nz
targetPrefix + member.name targetPrefix + member.name
}); });
} }
else if (builtinName)
{
fields.push_back({
member.name,
*builtinName
});
}
} }
AppendLine(); AppendLine();
}; };
@ -468,7 +422,7 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type)); assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
std::size_t inputStructIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex; std::size_t inputStructIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
inputStruct = &Retrieve(m_currentState->structs, inputStructIndex).structDesc; inputStruct = &Retrieve(m_currentState->structs, inputStructIndex);
AppendCommentSection("Inputs"); AppendCommentSection("Inputs");
AppendInOut(*inputStruct, m_currentState->inputFields, "in", s_inputPrefix); AppendInOut(*inputStruct, m_currentState->inputFields, "in", s_inputPrefix);
@ -485,20 +439,17 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType)); assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
std::size_t outputStructIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex; std::size_t outputStructIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
const ShaderAst::StructDescription& outputStruct = Retrieve(m_currentState->structs, outputStructIndex).structDesc; const ShaderAst::StructDescription& outputStruct = Retrieve(m_currentState->structs, outputStructIndex);
AppendCommentSection("Outputs"); AppendCommentSection("Outputs");
AppendInOut(outputStruct, m_currentState->outputFields, "out", s_outputPrefix); AppendInOut(outputStruct, m_currentState->outputFields, "out", s_outputPrefix);
} }
} }
void GlslWriter::RegisterStruct(std::size_t structIndex, bool isStd140, ShaderAst::StructDescription desc) void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc)
{ {
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
m_currentState->structs.emplace(structIndex, State::StructInfo{ m_currentState->structs.emplace(structIndex, std::move(desc));
std::move(desc),
isStd140
});
} }
void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName) void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName)
@ -667,16 +618,6 @@ namespace Nz
for (const auto& externalVar : node.externalVars) for (const auto& externalVar : node.externalVars)
{ {
std::optional<long long> bindingIndex;
for (const auto& [attributeType, attributeParam] : externalVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
{
bindingIndex = std::get<long long>(attributeParam);
break;
}
}
bool isStd140 = false; bool isStd140 = false;
if (IsUniformType(externalVar.type)) if (IsUniformType(externalVar.type))
{ {
@ -685,13 +626,13 @@ namespace Nz
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex; std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structInfo = Retrieve(m_currentState->structs, structIndex); auto& structInfo = Retrieve(m_currentState->structs, structIndex);
isStd140 = structInfo.isStd140; isStd140 = structInfo.layout == StructLayout_Std140;
} }
if (bindingIndex) if (externalVar.bindingIndex)
{ {
Append("layout(binding = "); Append("layout(binding = ");
Append(*bindingIndex); Append(*externalVar.bindingIndex);
if (isStd140) if (isStd140)
Append(", std140"); Append(", std140");
@ -708,19 +649,19 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType)); assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType));
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex; std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structInfo = Retrieve(m_currentState->structs, structIndex); auto& structDesc = Retrieve(m_currentState->structs, structIndex);
bool first = true; bool first = true;
for (const auto& [name, attribute, type] : structInfo.structDesc.members) for (const auto& member : structDesc.members)
{ {
if (!first) if (!first)
AppendLine(); AppendLine();
first = false; first = false;
Append(type); Append(member.type);
Append(" "); Append(" ");
Append(name); Append(member.name);
Append(";"); Append(";");
} }
} }
@ -778,34 +719,24 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node) void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{ {
bool isStd140 = false;
for (const auto& [attributeType, attributeParam] : node.attributes)
{
if (attributeType == ShaderAst::AttributeType::Layout && std::get<std::string>(attributeParam) == "std140")
{
isStd140 = true;
break;
}
}
assert(node.structIndex); assert(node.structIndex);
RegisterStruct(*node.structIndex, isStd140, node.description); RegisterStruct(*node.structIndex, node.description);
Append("struct "); Append("struct ");
AppendLine(node.description.name); AppendLine(node.description.name);
EnterScope(); EnterScope();
{ {
bool first = true; bool first = true;
for (const auto& [name, attribute, type] : node.description.members) for (const auto& member : node.description.members)
{ {
if (!first) if (!first)
AppendLine(); AppendLine();
first = false; first = false;
Append(type); Append(member.type);
Append(" "); Append(" ");
Append(name); Append(member.name);
Append(";"); Append(";");
} }
} }
@ -897,7 +828,7 @@ namespace Nz
const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr); const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr);
assert(IsStructType(returnType)); assert(IsStructType(returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex; std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex).structDesc; const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex);
std::string outputStructVarName; std::string outputStructVarName;
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression) if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression)

View File

@ -45,7 +45,6 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(DeclareExternalStatement& node) StatementPtr AstCloner::Clone(DeclareExternalStatement& node)
{ {
auto clone = std::make_unique<DeclareExternalStatement>(); auto clone = std::make_unique<DeclareExternalStatement>();
clone->attributes = node.attributes;
clone->externalVars = node.externalVars; clone->externalVars = node.externalVars;
clone->varIndex = node.varIndex; clone->varIndex = node.varIndex;
@ -55,7 +54,7 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(DeclareFunctionStatement& node) StatementPtr AstCloner::Clone(DeclareFunctionStatement& node)
{ {
auto clone = std::make_unique<DeclareFunctionStatement>(); auto clone = std::make_unique<DeclareFunctionStatement>();
clone->attributes = node.attributes; clone->entryStage = node.entryStage;
clone->funcIndex = node.funcIndex; clone->funcIndex = node.funcIndex;
clone->name = node.name; clone->name = node.name;
clone->parameters = node.parameters; clone->parameters = node.parameters;

View File

@ -115,6 +115,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareVariableStatement& node) void AstSerializerBase::Serialize(DeclareVariableStatement& node)
{ {
OptVal(node.varIndex);
Value(node.varName); Value(node.varName);
Type(node.varType); Type(node.varType);
Node(node.initialExpression); Node(node.initialExpression);
@ -168,14 +169,14 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareExternalStatement& node) void AstSerializerBase::Serialize(DeclareExternalStatement& node)
{ {
Attributes(node.attributes); OptVal(node.varIndex);
Container(node.externalVars); Container(node.externalVars);
for (auto& extVar : node.externalVars) for (auto& extVar : node.externalVars)
{ {
Attributes(extVar.attributes);
Value(extVar.name); Value(extVar.name);
Type(extVar.type); Type(extVar.type);
OptVal(extVar.bindingIndex);
} }
} }
@ -183,8 +184,9 @@ namespace Nz::ShaderAst
{ {
Value(node.name); Value(node.name);
Type(node.returnType); Type(node.returnType);
OptEnum(node.entryStage);
Attributes(node.attributes); OptVal(node.funcIndex);
OptVal(node.varIndex);
Container(node.parameters); Container(node.parameters);
for (auto& parameter : node.parameters) for (auto& parameter : node.parameters)
@ -200,13 +202,18 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareStructStatement& node) void AstSerializerBase::Serialize(DeclareStructStatement& node)
{ {
OptVal(node.structIndex);
Value(node.description.name); Value(node.description.name);
OptEnum(node.description.layout);
Container(node.description.members); Container(node.description.members);
for (auto& member : node.description.members) for (auto& member : node.description.members)
{ {
Value(member.name); Value(member.name);
Type(member.type); Type(member.type);
OptEnum(member.builtin);
OptVal(member.locationIndex);
} }
} }
@ -246,78 +253,6 @@ namespace Nz::ShaderAst
m_stream.FlushBits(); m_stream.FlushBits();
} }
void AstSerializerBase::Attributes(std::vector<Attribute>& attributes)
{
Container(attributes);
for (auto& attribute : attributes)
{
Enum(attribute.type);
if (IsWriting())
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::monostate>)
{
UInt8 typeId = 0;
Value(typeId);
}
else if constexpr (std::is_same_v<T, long long>)
{
UInt8 typeId = 1;
UInt64 v = UInt64(arg);
Value(typeId);
Value(v);
}
else if constexpr (std::is_same_v<T, std::string>)
{
UInt8 typeId = 2;
Value(typeId);
Value(arg);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, attribute.args);
}
else
{
UInt8 typeId;
Value(typeId);
switch (typeId)
{
case 0:
attribute.args.emplace<std::monostate>();
break;
case 1:
{
UInt64 arg;
Value(arg);
attribute.args = static_cast<long long>(arg);
break;
}
case 2:
{
std::string arg;
Value(arg);
attribute.args = std::move(arg);
break;
}
default:
throw std::runtime_error("invalid attribute type id");
}
}
}
}
bool ShaderAstSerializer::IsWriting() const bool ShaderAstSerializer::IsWriting() const
{ {
return true; return true;

View File

@ -11,14 +11,6 @@
namespace Nz::ShaderAst namespace Nz::ShaderAst
{ {
namespace
{
std::unordered_map<std::string, ShaderStageType> entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
}
struct AstError struct AstError
{ {
std::string errMsg; std::string errMsg;
@ -28,7 +20,7 @@ namespace Nz::ShaderAst
{ {
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {}; std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::unordered_set<std::string> declaredExternalVar; std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<long long> usedBindingIndexes; std::unordered_set<unsigned int> usedBindingIndexes;
}; };
bool AstValidator::Validate(StatementPtr& node, std::string* error) bool AstValidator::Validate(StatementPtr& node, std::string* error)
@ -552,44 +544,15 @@ namespace Nz::ShaderAst
void AstValidator::Visit(DeclareExternalStatement& node) void AstValidator::Visit(DeclareExternalStatement& node)
{ {
if (!node.attributes.empty())
throw AstError{ "unhandled attribute for external block" };
/*for (const auto& [attributeType, arg] : node.attributes)
{
switch (attributeType)
{
default:
throw AstError{ "unhandled attribute for external block" };
}
}*/
for (const auto& extVar : node.externalVars) for (const auto& extVar : node.externalVars)
{ {
bool hasBinding = false; if (extVar.bindingIndex)
for (const auto& [attributeType, arg] : extVar.attributes)
{ {
switch (attributeType) unsigned int bindingIndex = extVar.bindingIndex.value();
{
case AttributeType::Binding:
{
if (hasBinding)
throw AstError{ "attribute binding must be present once" };
if (!std::holds_alternative<long long>(arg))
throw AstError{ "attribute binding requires a string parameter" };
long long bindingIndex = std::get<long long>(arg);
if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end()) if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" }; throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" };
m_context->usedBindingIndexes.insert(bindingIndex); m_context->usedBindingIndexes.insert(bindingIndex);
break;
}
default:
throw AstError{ "unhandled attribute for external variable" };
}
} }
if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end()) if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end())
@ -603,42 +566,17 @@ namespace Nz::ShaderAst
void AstValidator::Visit(DeclareFunctionStatement& node) void AstValidator::Visit(DeclareFunctionStatement& node)
{ {
bool hasEntry = false; if (node.entryStage)
for (const auto& [attributeType, arg] : node.attributes)
{ {
switch (attributeType) ShaderStageType stageType = *node.entryStage;
{
case AttributeType::Entry:
{
if (hasEntry)
throw AstError{ "attribute entry must be present once" };
if (!std::holds_alternative<std::string>(arg))
throw AstError{ "attribute entry requires a string parameter" };
const std::string& argStr = std::get<std::string>(arg);
auto it = entryPoints.find(argStr);
if (it == entryPoints.end())
throw AstError{ "invalid parameter " + argStr + " for entry attribute" };
ShaderStageType stageType = it->second;
if (m_context->entryFunctions[UnderlyingCast(stageType)]) if (m_context->entryFunctions[UnderlyingCast(stageType)])
throw AstError{ "the same entry type has been defined multiple times" }; throw AstError{ "the same entry type has been defined multiple times" };
m_context->entryFunctions[UnderlyingCast(it->second)] = &node; m_context->entryFunctions[UnderlyingCast(stageType)] = &node;
if (node.parameters.size() > 1) if (node.parameters.size() > 1)
throw AstError{ "entry functions can either take one struct parameter or no parameter" }; throw AstError{ "entry functions can either take one struct parameter or no parameter" };
hasEntry = true;
break;
}
default:
throw AstError{ "unhandled attribute for function" };
}
} }
for (auto& statement : node.statements) for (auto& statement : node.statements)

View File

@ -30,6 +30,24 @@ namespace Nz::ShaderLang
{ "layout", ShaderAst::AttributeType::Layout }, { "layout", ShaderAst::AttributeType::Layout },
{ "location", ShaderAst::AttributeType::Location }, { "location", ShaderAst::AttributeType::Location },
}; };
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
std::unordered_map<std::string, ShaderAst::BuiltinEntry> s_builtinMapping = {
{ "position", ShaderAst::BuiltinEntry::VertexPosition }
};
template<typename T, typename U>
std::optional<T> BoundCast(U val)
{
if (val < std::numeric_limits<T>::min() || val > std::numeric_limits<T>::max())
return std::nullopt;
return static_cast<T>(val);
}
} }
ShaderAst::StatementPtr Parser::Parse(const std::vector<Token>& tokens) ShaderAst::StatementPtr Parser::Parse(const std::vector<Token>& tokens)
@ -305,14 +323,15 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes) ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes)
{ {
if (!attributes.empty())
throw AttributeError{ "unhandled attribute for external block" };
Expect(Advance(), TokenType::External); Expect(Advance(), TokenType::External);
Expect(Advance(), TokenType::OpenCurlyBracket); Expect(Advance(), TokenType::OpenCurlyBracket);
std::unique_ptr<ShaderAst::DeclareExternalStatement> externalStatement = std::make_unique<ShaderAst::DeclareExternalStatement>(); std::unique_ptr<ShaderAst::DeclareExternalStatement> externalStatement = std::make_unique<ShaderAst::DeclareExternalStatement>();
externalStatement->attributes = std::move(attributes);
bool first = true; bool first = true;
for (;;) for (;;)
{ {
if (!first) if (!first)
@ -336,7 +355,32 @@ namespace Nz::ShaderLang
auto& extVar = externalStatement->externalVars.emplace_back(); auto& extVar = externalStatement->externalVars.emplace_back();
if (token.type == TokenType::OpenAttribute) if (token.type == TokenType::OpenAttribute)
extVar.attributes = ParseAttributes(); {
for (const auto& [attributeType, arg] : ParseAttributes())
{
switch (attributeType)
{
case ShaderAst::AttributeType::Binding:
{
if (extVar.bindingIndex)
throw AttributeError{ "attribute binding must be present once" };
if (!std::holds_alternative<long long>(arg))
throw AttributeError{ "attribute binding requires a string parameter" };
std::optional<unsigned int> bindingIndex = BoundCast<unsigned int>(std::get<long long>(arg));
if (!bindingIndex)
throw AttributeError{ "invalid binding index" };
extVar.bindingIndex = bindingIndex.value();
break;
}
default:
throw AttributeError{ "unhandled attribute for external variable" };
}
}
}
extVar.name = ParseIdentifierAsName(); extVar.name = ParseIdentifierAsName();
Expect(Advance(), TokenType::Colon); Expect(Advance(), TokenType::Colon);
@ -402,7 +446,35 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingCurlyBracket); Expect(Advance(), TokenType::ClosingCurlyBracket);
return ShaderBuilder::DeclareFunction(std::move(attributes), std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType)); std::optional<ShaderStageType> entryPoint;
for (const auto& [attributeType, arg] : attributes)
{
switch (attributeType)
{
case ShaderAst::AttributeType::Entry:
{
if (entryPoint)
throw AttributeError{ "attribute entry must be present once" };
if (!std::holds_alternative<std::string>(arg))
throw AttributeError{ "attribute entry requires a string parameter" };
const std::string& argStr = std::get<std::string>(arg);
auto it = s_entryPoints.find(argStr);
if (it == s_entryPoints.end())
throw AttributeError{ ("invalid parameter " + argStr + " for entry attribute").c_str() };
entryPoint = it->second;
break;
}
default:
throw AttributeError{ "unhandled attribute for function" };
}
}
return ShaderBuilder::DeclareFunction(entryPoint, std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
} }
ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter() ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter()
@ -450,7 +522,41 @@ namespace Nz::ShaderLang
auto& structField = description.members.emplace_back(); auto& structField = description.members.emplace_back();
if (token.type == TokenType::OpenAttribute) if (token.type == TokenType::OpenAttribute)
structField.attributes = ParseAttributes(); {
for (const auto& [attributeType, attributeParam] : ParseAttributes())
{
switch (attributeType)
{
case ShaderAst::AttributeType::Builtin:
{
if (structField.builtin)
throw AttributeError{ "attribute builtin must be present once" };
auto it = s_builtinMapping.find(std::get<std::string>(attributeParam));
if (it == s_builtinMapping.end())
throw AttributeError{ "unknown builtin" };
structField.builtin = it->second;
break;
}
case ShaderAst::AttributeType::Location:
{
if (structField.locationIndex)
throw AttributeError{ "attribute location must be present once" };
structField.locationIndex = BoundCast<unsigned int>(std::get<long long>(attributeParam));
if (!structField.locationIndex)
throw AttributeError{ "invalid location index" };
break;
}
}
}
if (structField.builtin && structField.locationIndex)
throw AttributeError{ "A struct field cannot have both builtin and location attributes" };
}
structField.name = ParseIdentifierAsName(); structField.name = ParseIdentifierAsName();
@ -461,7 +567,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingCurlyBracket); Expect(Advance(), TokenType::ClosingCurlyBracket);
return ShaderBuilder::DeclareStruct(std::move(attributes), std::move(description)); return ShaderBuilder::DeclareStruct(std::move(description));
} }
ShaderAst::StatementPtr Parser::ParseReturnStatement() ShaderAst::StatementPtr Parser::ParseReturnStatement()

View File

@ -27,12 +27,6 @@ namespace Nz
{ {
namespace namespace
{ {
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
struct Builtin struct Builtin
{ {
const char* debugName; const char* debugName;
@ -40,8 +34,8 @@ namespace Nz
SpirvBuiltIn decoration; SpirvBuiltIn decoration;
}; };
std::unordered_map<std::string, Builtin> s_builtinMapping = { std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
{ "position", { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } } { ShaderAst::BuiltinEntry::VertexPosition, { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
}; };
class PreVisitor : public ShaderAst::AstScopedVisitor class PreVisitor : public ShaderAst::AstScopedVisitor
@ -129,32 +123,13 @@ namespace Nz
UniformVar& uniformVar = extVars[varIndex++]; UniformVar& uniformVar = extVars[varIndex++];
uniformVar.pointerId = m_constantCache.Register(variable); uniformVar.pointerId = m_constantCache.Register(variable);
uniformVar.bindingIndex = extVar.bindingIndex;
for (const auto& [attributeType, attributeParam] : extVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
{
uniformVar.bindingIndex = std::get<long long>(attributeParam);
break;
}
}
} }
} }
void Visit(ShaderAst::DeclareFunctionStatement& node) override void Visit(ShaderAst::DeclareFunctionStatement& node) override
{ {
std::optional<ShaderStageType> entryPointType; std::optional<ShaderStageType> entryPointType = node.entryStage;
for (auto& attribute : node.attributes)
{
if (attribute.type == ShaderAst::AttributeType::Entry)
{
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
entryPointType = it->second;
break;
}
}
assert(node.funcIndex); assert(node.funcIndex);
std::size_t funcIndex = *node.funcIndex; std::size_t funcIndex = *node.funcIndex;
@ -325,29 +300,12 @@ namespace Nz
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass) UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
{ {
std::optional<std::reference_wrapper<Builtin>> builtinOpt; if (member.builtin)
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
{ {
if (attributeType == ShaderAst::AttributeType::Builtin) auto it = s_builtinMapping.find(*member.builtin);
{ assert(it != s_builtinMapping.end());
auto it = s_builtinMapping.find(std::get<std::string>(attributeParam));
if (it != s_builtinMapping.end())
{
builtinOpt = it->second;
break;
}
}
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
}
if (builtinOpt) Builtin& builtin = it->second;
{
Builtin& builtin = *builtinOpt;
if ((builtin.compatibleStages & entryPointType) == 0) if ((builtin.compatibleStages & entryPointType) == 0)
return 0; return 0;
@ -364,7 +322,7 @@ namespace Nz
return varId; return varId;
} }
else if (attributeLocation) else if (member.locationIndex)
{ {
SpirvConstantCache::Variable variable; SpirvConstantCache::Variable variable;
variable.debugName = member.name; variable.debugName = member.name;
@ -373,7 +331,7 @@ namespace Nz
variable.type = m_constantCache.BuildPointerType(member.type, storageClass); variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
UInt32 varId = m_constantCache.Register(variable); UInt32 varId = m_constantCache.Register(variable);
locationDecorations[varId] = *attributeLocation; locationDecorations[varId] = *member.locationIndex;
return varId; return varId;
} }