Shader: Move attribute parsing to parser, simplifying writer code

This commit is contained in:
Jérôme Leclercq 2021-04-14 11:34:21 +02:00
parent bca1561f73
commit aababb205f
14 changed files with 910 additions and 1046 deletions

View File

@ -11,6 +11,7 @@
#include <Nazara/Utility/Enums.hpp>
#include <Nazara/Shader/ShaderEnums.hpp>
#include <Nazara/Shader/Ast/Attribute.hpp>
#include <optional>
#include <string>
#include <variant>
#include <vector>
@ -81,11 +82,13 @@ namespace Nz::ShaderAst
{
struct StructMember
{
std::optional<BuiltinEntry> builtin;
std::optional<unsigned int> locationIndex;
std::string name;
std::vector<Attribute> attributes;
ExpressionType type;
};
std::optional<StructLayout> layout;
std::string name;
std::vector<StructMember> members;
};

View File

@ -71,7 +71,7 @@ namespace Nz
void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node);
void HandleInOut();
void RegisterStruct(std::size_t structIndex, bool isStd140, ShaderAst::StructDescription desc);
void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc);
void RegisterVariable(std::size_t varIndex, std::string varName);
void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false);

View File

@ -48,7 +48,6 @@ namespace Nz::ShaderAst
void Serialize(ReturnStatement& node);
protected:
void Attributes(std::vector<Attribute>& attributes);
template<typename T> void Container(T& container);
template<typename T> void Enum(T& enumVal);
template<typename T> void OptEnum(std::optional<T>& optVal);

View File

@ -10,6 +10,7 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/ShaderNodes.hpp>
#include <memory>
#include <optional>
namespace Nz::ShaderBuilder
{
@ -59,13 +60,12 @@ namespace Nz::ShaderBuilder
struct DeclareFunction
{
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const;
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::vector<ShaderAst::Attribute> attributes, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const;
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::optional<ShaderStageType> entryStage, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType = ShaderAst::NoType{}) const;
};
struct DeclareStruct
{
inline std::unique_ptr<ShaderAst::DeclareStructStatement> operator()(ShaderAst::StructDescription description) const;
inline std::unique_ptr<ShaderAst::DeclareStructStatement> operator()(std::vector<ShaderAst::Attribute> attributes, ShaderAst::StructDescription description) const;
};
struct DeclareVariable

View File

@ -108,10 +108,10 @@ namespace Nz::ShaderBuilder
return declareFunctionNode;
}
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::vector<ShaderAst::Attribute> attributes, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType) const
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::optional<ShaderStageType> entryStage, std::string name, std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters, std::vector<ShaderAst::StatementPtr> statements, ShaderAst::ExpressionType returnType) const
{
auto declareFunctionNode = std::make_unique<ShaderAst::DeclareFunctionStatement>();
declareFunctionNode->attributes = std::move(attributes);
declareFunctionNode->entryStage = entryStage;
declareFunctionNode->name = std::move(name);
declareFunctionNode->parameters = std::move(parameters);
declareFunctionNode->returnType = std::move(returnType);
@ -128,15 +128,6 @@ namespace Nz::ShaderBuilder
return declareStructNode;
}
inline std::unique_ptr<ShaderAst::DeclareStructStatement> Impl::DeclareStruct::operator()(std::vector<ShaderAst::Attribute> attributes, ShaderAst::StructDescription description) const
{
auto declareStructNode = std::make_unique<ShaderAst::DeclareStructStatement>();
declareStructNode->attributes = std::move(attributes);
declareStructNode->description = std::move(description);
return declareStructNode;
}
inline std::unique_ptr<ShaderAst::DeclareVariableStatement> Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const
{
auto declareVariableNode = std::make_unique<ShaderAst::DeclareVariableStatement>();

View File

@ -25,14 +25,6 @@ namespace Nz::ShaderAst
Location //< Location (struct member only) - has argument index
};
enum class PrimitiveType
{
Boolean, //< bool
Float32, //< f32
Int32, //< i32
UInt32, //< ui32
};
enum class BinaryType
{
Add, //< +
@ -80,6 +72,14 @@ namespace Nz::ShaderAst
#include <Nazara/Shader/ShaderAstNodes.hpp>
};
enum class PrimitiveType
{
Boolean, //< bool
Float32, //< f32
Int32, //< i32
UInt32, //< ui32
};
enum class SwizzleComponent
{
First,

View File

@ -14,6 +14,12 @@
namespace Nz::ShaderLang
{
class AttributeError : public std::exception
{
public:
using exception::exception;
};
class ExpectedToken : public std::exception
{
public:

View File

@ -210,13 +210,12 @@ namespace Nz::ShaderAst
struct ExternalVar
{
std::optional<unsigned int> bindingIndex;
std::string name;
std::vector<Attribute> attributes;
ExpressionType type;
};
std::optional<std::size_t> varIndex;
std::vector<Attribute> attributes;
std::vector<ExternalVar> externalVars;
};
@ -231,10 +230,10 @@ namespace Nz::ShaderAst
ExpressionType type;
};
std::optional<ShaderStageType> entryStage;
std::optional<std::size_t> funcIndex;
std::optional<std::size_t> varIndex;
std::string name;
std::vector<Attribute> attributes;
std::vector<Parameter> parameters;
std::vector<StatementPtr> statements;
ExpressionType returnType;
@ -246,7 +245,6 @@ namespace Nz::ShaderAst
void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> structIndex;
std::vector<Attribute> attributes;
StructDescription description;
};

View File

@ -24,12 +24,6 @@ namespace Nz
static const char* s_outputPrefix = "_NzOut_";
static const char* s_outputVarName = "_nzOutput";
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
template<typename T> const T& Retrieve(const std::unordered_map<std::size_t, T>& map, std::size_t id)
{
auto it = map.find(id);
@ -49,29 +43,19 @@ namespace Nz
ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get());
// Remove function if it's an entry point of another type than the one selected
bool isEntryPoint = false;
for (auto& attribute : func->attributes)
if (node.entryStage)
{
if (attribute.type == ShaderAst::AttributeType::Entry)
{
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
ShaderStageType stage = *node.entryStage;
if (stage != selectedStage)
return ShaderBuilder::NoOp();
if (it->second != selectedEntryPoint)
return ShaderBuilder::NoOp();
isEntryPoint = true;
break;
}
}
if (isEntryPoint)
entryPoint = func;
}
return clone;
}
ShaderStageType selectedEntryPoint;
ShaderStageType selectedStage;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
};
@ -81,8 +65,8 @@ namespace Nz
ShaderStageTypeFlags stageFlags;
};
std::unordered_map<std::string, Builtin> builtinMapping = {
{ "position", { "gl_Position", ShaderStageType::Vertex } }
std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
{ ShaderAst::BuiltinEntry::VertexPosition, { "gl_Position", ShaderStageType::Vertex } }
};
}
@ -95,17 +79,11 @@ namespace Nz
std::string targetName;
};
struct StructInfo
{
ShaderAst::StructDescription structDesc;
bool isStd140 = false;
};
ShaderStageType stage;
const States* states = nullptr;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream;
std::unordered_map<std::size_t, StructInfo> structs;
std::unordered_map<std::size_t, ShaderAst::StructDescription> structs;
std::unordered_map<std::size_t, std::string> variableNames;
std::vector<InOutField> inputFields;
std::vector<InOutField> outputFields;
@ -138,7 +116,7 @@ namespace Nz
ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader);
PreVisitor previsitor;
previsitor.selectedEntryPoint = shaderStage;
previsitor.selectedStage = shaderStage;
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(transformedShader);
@ -241,13 +219,13 @@ namespace Nz
void GlslWriter::Append(const ShaderAst::StructType& structType)
{
const auto& structDesc = Retrieve(m_currentState->structs, structType.structIndex).structDesc;
const auto& structDesc = Retrieve(m_currentState->structs, structType.structIndex);
Append(structDesc.name);
}
void GlslWriter::Append(const ShaderAst::UniformType& uniformType)
{
/* TODO */
throw std::runtime_error("unexpected UniformType");
}
void GlslWriter::Append(const ShaderAst::VectorType& vecType)
@ -305,7 +283,7 @@ namespace Nz
void GlslWriter::AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers)
{
const auto& structDesc = Retrieve(m_currentState->structs, structIndex).structDesc;
const auto& structDesc = Retrieve(m_currentState->structs, structIndex);
const auto& member = structDesc.members[*memberIndices];
@ -371,7 +349,7 @@ namespace Nz
assert(IsStructType(parameter.type));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex).structDesc;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex);
AppendLine(structDesc.name, " ", varName, ";");
for (const auto& [memberName, targetName] : m_currentState->inputFields)
@ -397,41 +375,24 @@ namespace Nz
{
for (const auto& member : structDesc.members)
{
bool skip = false;
std::optional<std::string> builtinName;
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
if (member.builtin)
{
if (attributeType == ShaderAst::AttributeType::Builtin)
{
auto it = builtinMapping.find(std::get<std::string>(attributeParam));
if (it != builtinMapping.end())
{
const Builtin& builtin = it->second;
if (!builtin.stageFlags.Test(m_currentState->stage))
{
skip = true;
break;
}
auto it = s_builtinMapping.find(member.builtin.value());
assert(it != s_builtinMapping.end());
builtinName = builtin.identifier;
break;
}
}
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
const Builtin& builtin = it->second;
if (!builtin.stageFlags.Test(m_currentState->stage))
continue; //< This builtin is not active in this stage, skip it
fields.push_back({
member.name,
builtin.identifier
});
}
if (skip)
continue;
if (attributeLocation)
else if (member.locationIndex)
{
Append("layout(location = ");
Append(*attributeLocation);
Append(*member.locationIndex);
Append(") ");
Append(keyword);
Append(" ");
@ -446,13 +407,6 @@ namespace Nz
targetPrefix + member.name
});
}
else if (builtinName)
{
fields.push_back({
member.name,
*builtinName
});
}
}
AppendLine();
};
@ -468,7 +422,7 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
std::size_t inputStructIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
inputStruct = &Retrieve(m_currentState->structs, inputStructIndex).structDesc;
inputStruct = &Retrieve(m_currentState->structs, inputStructIndex);
AppendCommentSection("Inputs");
AppendInOut(*inputStruct, m_currentState->inputFields, "in", s_inputPrefix);
@ -485,20 +439,17 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
std::size_t outputStructIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
const ShaderAst::StructDescription& outputStruct = Retrieve(m_currentState->structs, outputStructIndex).structDesc;
const ShaderAst::StructDescription& outputStruct = Retrieve(m_currentState->structs, outputStructIndex);
AppendCommentSection("Outputs");
AppendInOut(outputStruct, m_currentState->outputFields, "out", s_outputPrefix);
}
}
void GlslWriter::RegisterStruct(std::size_t structIndex, bool isStd140, ShaderAst::StructDescription desc)
void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc)
{
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
m_currentState->structs.emplace(structIndex, State::StructInfo{
std::move(desc),
isStd140
});
m_currentState->structs.emplace(structIndex, std::move(desc));
}
void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName)
@ -667,16 +618,6 @@ namespace Nz
for (const auto& externalVar : node.externalVars)
{
std::optional<long long> bindingIndex;
for (const auto& [attributeType, attributeParam] : externalVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
{
bindingIndex = std::get<long long>(attributeParam);
break;
}
}
bool isStd140 = false;
if (IsUniformType(externalVar.type))
{
@ -685,13 +626,13 @@ namespace Nz
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structInfo = Retrieve(m_currentState->structs, structIndex);
isStd140 = structInfo.isStd140;
isStd140 = structInfo.layout == StructLayout_Std140;
}
if (bindingIndex)
if (externalVar.bindingIndex)
{
Append("layout(binding = ");
Append(*bindingIndex);
Append(*externalVar.bindingIndex);
if (isStd140)
Append(", std140");
@ -708,19 +649,19 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType));
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structInfo = Retrieve(m_currentState->structs, structIndex);
auto& structDesc = Retrieve(m_currentState->structs, structIndex);
bool first = true;
for (const auto& [name, attribute, type] : structInfo.structDesc.members)
for (const auto& member : structDesc.members)
{
if (!first)
AppendLine();
first = false;
Append(type);
Append(member.type);
Append(" ");
Append(name);
Append(member.name);
Append(";");
}
}
@ -778,34 +719,24 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
bool isStd140 = false;
for (const auto& [attributeType, attributeParam] : node.attributes)
{
if (attributeType == ShaderAst::AttributeType::Layout && std::get<std::string>(attributeParam) == "std140")
{
isStd140 = true;
break;
}
}
assert(node.structIndex);
RegisterStruct(*node.structIndex, isStd140, node.description);
RegisterStruct(*node.structIndex, node.description);
Append("struct ");
AppendLine(node.description.name);
EnterScope();
{
bool first = true;
for (const auto& [name, attribute, type] : node.description.members)
for (const auto& member : node.description.members)
{
if (!first)
AppendLine();
first = false;
Append(type);
Append(member.type);
Append(" ");
Append(name);
Append(member.name);
Append(";");
}
}
@ -897,7 +828,7 @@ namespace Nz
const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr);
assert(IsStructType(returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex).structDesc;
const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex);
std::string outputStructVarName;
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression)

View File

@ -45,7 +45,6 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(DeclareExternalStatement& node)
{
auto clone = std::make_unique<DeclareExternalStatement>();
clone->attributes = node.attributes;
clone->externalVars = node.externalVars;
clone->varIndex = node.varIndex;
@ -55,7 +54,7 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(DeclareFunctionStatement& node)
{
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->attributes = node.attributes;
clone->entryStage = node.entryStage;
clone->funcIndex = node.funcIndex;
clone->name = node.name;
clone->parameters = node.parameters;

View File

@ -115,6 +115,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareVariableStatement& node)
{
OptVal(node.varIndex);
Value(node.varName);
Type(node.varType);
Node(node.initialExpression);
@ -168,14 +169,14 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareExternalStatement& node)
{
Attributes(node.attributes);
OptVal(node.varIndex);
Container(node.externalVars);
for (auto& extVar : node.externalVars)
{
Attributes(extVar.attributes);
Value(extVar.name);
Type(extVar.type);
OptVal(extVar.bindingIndex);
}
}
@ -183,8 +184,9 @@ namespace Nz::ShaderAst
{
Value(node.name);
Type(node.returnType);
Attributes(node.attributes);
OptEnum(node.entryStage);
OptVal(node.funcIndex);
OptVal(node.varIndex);
Container(node.parameters);
for (auto& parameter : node.parameters)
@ -200,13 +202,18 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareStructStatement& node)
{
OptVal(node.structIndex);
Value(node.description.name);
OptEnum(node.description.layout);
Container(node.description.members);
for (auto& member : node.description.members)
{
Value(member.name);
Type(member.type);
OptEnum(member.builtin);
OptVal(member.locationIndex);
}
}
@ -246,78 +253,6 @@ namespace Nz::ShaderAst
m_stream.FlushBits();
}
void AstSerializerBase::Attributes(std::vector<Attribute>& attributes)
{
Container(attributes);
for (auto& attribute : attributes)
{
Enum(attribute.type);
if (IsWriting())
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::monostate>)
{
UInt8 typeId = 0;
Value(typeId);
}
else if constexpr (std::is_same_v<T, long long>)
{
UInt8 typeId = 1;
UInt64 v = UInt64(arg);
Value(typeId);
Value(v);
}
else if constexpr (std::is_same_v<T, std::string>)
{
UInt8 typeId = 2;
Value(typeId);
Value(arg);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, attribute.args);
}
else
{
UInt8 typeId;
Value(typeId);
switch (typeId)
{
case 0:
attribute.args.emplace<std::monostate>();
break;
case 1:
{
UInt64 arg;
Value(arg);
attribute.args = static_cast<long long>(arg);
break;
}
case 2:
{
std::string arg;
Value(arg);
attribute.args = std::move(arg);
break;
}
default:
throw std::runtime_error("invalid attribute type id");
}
}
}
}
bool ShaderAstSerializer::IsWriting() const
{
return true;

View File

@ -11,14 +11,6 @@
namespace Nz::ShaderAst
{
namespace
{
std::unordered_map<std::string, ShaderStageType> entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
}
struct AstError
{
std::string errMsg;
@ -28,7 +20,7 @@ namespace Nz::ShaderAst
{
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<long long> usedBindingIndexes;
std::unordered_set<unsigned int> usedBindingIndexes;
};
bool AstValidator::Validate(StatementPtr& node, std::string* error)
@ -552,44 +544,15 @@ namespace Nz::ShaderAst
void AstValidator::Visit(DeclareExternalStatement& node)
{
if (!node.attributes.empty())
throw AstError{ "unhandled attribute for external block" };
/*for (const auto& [attributeType, arg] : node.attributes)
{
switch (attributeType)
{
default:
throw AstError{ "unhandled attribute for external block" };
}
}*/
for (const auto& extVar : node.externalVars)
{
bool hasBinding = false;
for (const auto& [attributeType, arg] : extVar.attributes)
if (extVar.bindingIndex)
{
switch (attributeType)
{
case AttributeType::Binding:
{
if (hasBinding)
throw AstError{ "attribute binding must be present once" };
unsigned int bindingIndex = extVar.bindingIndex.value();
if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" };
if (!std::holds_alternative<long long>(arg))
throw AstError{ "attribute binding requires a string parameter" };
long long bindingIndex = std::get<long long>(arg);
if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" };
m_context->usedBindingIndexes.insert(bindingIndex);
break;
}
default:
throw AstError{ "unhandled attribute for external variable" };
}
m_context->usedBindingIndexes.insert(bindingIndex);
}
if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end())
@ -603,42 +566,17 @@ namespace Nz::ShaderAst
void AstValidator::Visit(DeclareFunctionStatement& node)
{
bool hasEntry = false;
for (const auto& [attributeType, arg] : node.attributes)
if (node.entryStage)
{
switch (attributeType)
{
case AttributeType::Entry:
{
if (hasEntry)
throw AstError{ "attribute entry must be present once" };
ShaderStageType stageType = *node.entryStage;
if (!std::holds_alternative<std::string>(arg))
throw AstError{ "attribute entry requires a string parameter" };
if (m_context->entryFunctions[UnderlyingCast(stageType)])
throw AstError{ "the same entry type has been defined multiple times" };
const std::string& argStr = std::get<std::string>(arg);
m_context->entryFunctions[UnderlyingCast(stageType)] = &node;
auto it = entryPoints.find(argStr);
if (it == entryPoints.end())
throw AstError{ "invalid parameter " + argStr + " for entry attribute" };
ShaderStageType stageType = it->second;
if (m_context->entryFunctions[UnderlyingCast(stageType)])
throw AstError{ "the same entry type has been defined multiple times" };
m_context->entryFunctions[UnderlyingCast(it->second)] = &node;
if (node.parameters.size() > 1)
throw AstError{ "entry functions can either take one struct parameter or no parameter" };
hasEntry = true;
break;
}
default:
throw AstError{ "unhandled attribute for function" };
}
if (node.parameters.size() > 1)
throw AstError{ "entry functions can either take one struct parameter or no parameter" };
}
for (auto& statement : node.statements)

File diff suppressed because it is too large Load Diff

View File

@ -27,12 +27,6 @@ namespace Nz
{
namespace
{
//FIXME: Have this only once
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
struct Builtin
{
const char* debugName;
@ -40,8 +34,8 @@ namespace Nz
SpirvBuiltIn decoration;
};
std::unordered_map<std::string, Builtin> s_builtinMapping = {
{ "position", { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
{ ShaderAst::BuiltinEntry::VertexPosition, { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } }
};
class PreVisitor : public ShaderAst::AstScopedVisitor
@ -129,32 +123,13 @@ namespace Nz
UniformVar& uniformVar = extVars[varIndex++];
uniformVar.pointerId = m_constantCache.Register(variable);
for (const auto& [attributeType, attributeParam] : extVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
{
uniformVar.bindingIndex = std::get<long long>(attributeParam);
break;
}
}
uniformVar.bindingIndex = extVar.bindingIndex;
}
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
std::optional<ShaderStageType> entryPointType;
for (auto& attribute : node.attributes)
{
if (attribute.type == ShaderAst::AttributeType::Entry)
{
auto it = s_entryPoints.find(std::get<std::string>(attribute.args));
assert(it != s_entryPoints.end());
entryPointType = it->second;
break;
}
}
std::optional<ShaderStageType> entryPointType = node.entryStage;
assert(node.funcIndex);
std::size_t funcIndex = *node.funcIndex;
@ -325,29 +300,12 @@ namespace Nz
UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass)
{
std::optional<std::reference_wrapper<Builtin>> builtinOpt;
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
if (member.builtin)
{
if (attributeType == ShaderAst::AttributeType::Builtin)
{
auto it = s_builtinMapping.find(std::get<std::string>(attributeParam));
if (it != s_builtinMapping.end())
{
builtinOpt = it->second;
break;
}
}
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
}
auto it = s_builtinMapping.find(*member.builtin);
assert(it != s_builtinMapping.end());
if (builtinOpt)
{
Builtin& builtin = *builtinOpt;
Builtin& builtin = it->second;
if ((builtin.compatibleStages & entryPointType) == 0)
return 0;
@ -364,7 +322,7 @@ namespace Nz
return varId;
}
else if (attributeLocation)
else if (member.locationIndex)
{
SpirvConstantCache::Variable variable;
variable.debugName = member.name;
@ -373,7 +331,7 @@ namespace Nz
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
UInt32 varId = m_constantCache.Register(variable);
locationDecorations[varId] = *attributeLocation;
locationDecorations[varId] = *member.locationIndex;
return varId;
}