diff --git a/include/Nazara/Shader/ShaderAstCache.hpp b/include/Nazara/Shader/ShaderAstCache.hpp index 595b6b410..6e1b0a144 100644 --- a/include/Nazara/Shader/ShaderAstCache.hpp +++ b/include/Nazara/Shader/ShaderAstCache.hpp @@ -37,6 +37,7 @@ namespace Nz::ShaderAst inline std::size_t GetScopeId(const Node* node) const; ShaderStageType stageType = ShaderStageType::Undefined; + std::array entryFunctions = {}; std::unordered_map nodeExpressionType; std::unordered_map scopeIdByNode; std::vector scopes; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 6a18b25ed..625875470 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -34,6 +34,7 @@ namespace Nz::ShaderBuilder struct DeclareFunction { inline std::unique_ptr operator()(std::string name, std::vector parameters, std::vector statements, ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void) const; + inline std::unique_ptr operator()(std::vector attributes, std::string name, std::vector parameters, std::vector statements, ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void) const; }; struct DeclareVariable diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index ef7e89849..31fb4dcf9 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -58,6 +58,18 @@ namespace Nz::ShaderBuilder return declareFunctionNode; } + inline std::unique_ptr Impl::DeclareFunction::operator()(std::vector attributes, std::string name, std::vector parameters, std::vector statements, ShaderAst::ShaderExpressionType returnType) const + { + auto declareFunctionNode = std::make_unique(); + declareFunctionNode->attributes = std::move(attributes); + declareFunctionNode->name = std::move(name); + declareFunctionNode->parameters = std::move(parameters); + declareFunctionNode->returnType = std::move(returnType); + declareFunctionNode->statements = std::move(statements); + + return declareFunctionNode; + } + inline std::unique_ptr Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ShaderExpressionType type, ShaderAst::ExpressionPtr initialValue) const { auto declareVariableNode = std::make_unique(); diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index 785136702..ba727b73a 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -57,14 +57,15 @@ namespace Nz::ShaderLang const Token& Advance(); void Consume(std::size_t count = 1); const Token& Expect(const Token& token, TokenType type); + const Token& ExpectNot(const Token& token, TokenType type); const Token& Expect(TokenType type); const Token& Peek(std::size_t advance = 0); - std::vector ParseAttributes(); + void HandleAttributes(); // Statements std::vector ParseFunctionBody(); - ShaderAst::StatementPtr ParseFunctionDeclaration(); + ShaderAst::StatementPtr ParseFunctionDeclaration(std::vector attributes = {}); ShaderAst::DeclareFunctionStatement::Parameter ParseFunctionParameter(); ShaderAst::StatementPtr ParseReturnStatement(); ShaderAst::StatementPtr ParseStatement(); @@ -87,7 +88,6 @@ namespace Nz::ShaderLang struct Context { - std::vector pendingAttributes; std::unique_ptr root; std::size_t tokenCount; std::size_t tokenIndex = 0; diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index 4e532ad14..386f53168 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -201,6 +201,7 @@ namespace Nz::ShaderAst }; std::string name; + std::vector attributes; std::vector parameters; std::vector statements; ShaderExpressionType returnType = BasicType::Void; diff --git a/src/Nazara/Shader/ShaderAstCloner.cpp b/src/Nazara/Shader/ShaderAstCloner.cpp index 2948cd753..ea7ed0276 100644 --- a/src/Nazara/Shader/ShaderAstCloner.cpp +++ b/src/Nazara/Shader/ShaderAstCloner.cpp @@ -165,6 +165,7 @@ namespace Nz::ShaderAst void AstCloner::Visit(DeclareFunctionStatement& node) { auto clone = std::make_unique(); + clone->attributes = node.attributes; clone->name = node.name; clone->parameters = node.parameters; clone->returnType = node.returnType; diff --git a/src/Nazara/Shader/ShaderAstExpressionType.cpp b/src/Nazara/Shader/ShaderAstExpressionType.cpp index 9a6b394d7..71f831f69 100644 --- a/src/Nazara/Shader/ShaderAstExpressionType.cpp +++ b/src/Nazara/Shader/ShaderAstExpressionType.cpp @@ -192,7 +192,7 @@ namespace Nz::ShaderAst void ExpressionTypeVisitor::Visit(SwizzleExpression& node) { - const ShaderExpressionType& exprType = GetExpressionTypeInternal(*node.expression); + ShaderExpressionType exprType = GetExpressionTypeInternal(*node.expression); assert(IsBasicType(exprType)); m_lastExpressionType = static_cast(UnderlyingCast(GetComponentType(std::get(exprType))) + node.componentCount - 1); diff --git a/src/Nazara/Shader/ShaderAstSerializer.cpp b/src/Nazara/Shader/ShaderAstSerializer.cpp index 4100a0204..b078a53f8 100644 --- a/src/Nazara/Shader/ShaderAstSerializer.cpp +++ b/src/Nazara/Shader/ShaderAstSerializer.cpp @@ -157,6 +157,13 @@ namespace Nz::ShaderAst Value(node.name); Type(node.returnType); + Container(node.attributes); + for (auto& attribute : node.attributes) + { + Enum(attribute.type); + Value(attribute.args); + } + Container(node.parameters); for (auto& parameter : node.parameters) { diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index 0e5d8f52f..2c41cf999 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -6,11 +6,21 @@ #include #include #include +#include #include #include namespace Nz::ShaderAst { + namespace + { + std::unordered_map entryPoints = { + { "frag", ShaderStageType::Fragment }, + { "vert", ShaderStageType::Vertex }, + }; + + } + struct AstError { std::string errMsg; @@ -135,7 +145,7 @@ namespace Nz::ShaderAst { RegisterScope(node); - const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache); + ShaderExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache); if (!IsStructType(exprType)) throw AstError{ "expression is not a structure" }; @@ -165,11 +175,11 @@ namespace Nz::ShaderAst // Register expression type AstRecursiveVisitor::Visit(node); - const ShaderExpressionType& leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache); + ShaderExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache); if (!IsBasicType(leftExprType)) throw AstError{ "left expression type does not support binary operation" }; - const ShaderExpressionType& rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache); + ShaderExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache); if (!IsBasicType(rightExprType)) throw AstError{ "right expression type does not support binary operation" }; @@ -349,7 +359,7 @@ namespace Nz::ShaderAst if (node.componentCount > 4) throw AstError{ "Cannot swizzle more than four elements" }; - const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache); + ShaderExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache); if (!IsBasicType(exprType)) throw AstError{ "Cannot swizzle this type" }; @@ -378,7 +388,7 @@ namespace Nz::ShaderAst for (auto& condStatement : node.condStatements) { - const ShaderExpressionType& condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache); + ShaderExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache); if (!IsBasicType(condType) || std::get(condType) != BasicType::Boolean) throw AstError{ "if expression must resolve to boolean type" }; @@ -401,8 +411,40 @@ namespace Nz::ShaderAst void AstValidator::Visit(DeclareFunctionStatement& node) { - auto& scope = EnterScope(); + bool hasEntry = false; + for (const auto& [attributeType, arg] : node.attributes) + { + switch (attributeType) + { + case AttributeType::Entry: + { + if (hasEntry) + throw AstError{ "attribute entry must be present once" }; + if (arg.empty()) + throw AstError{ "attribute entry requires a parameter" }; + + auto it = entryPoints.find(arg); + if (it == entryPoints.end()) + throw AstError{ "invalid parameter " + arg + " for entry attribute" }; + + ShaderStageType stageType = it->second; + + if (m_context->cache->entryFunctions[UnderlyingCast(stageType)]) + throw AstError{ "the same entry type has been defined multiple times" }; + + m_context->cache->entryFunctions[UnderlyingCast(it->second)] = &node; + + hasEntry = true; + break; + } + + default: + throw AstError{ "unhandled attribute for function" }; + } + } + + auto& scope = EnterScope(); RegisterScope(node); for (auto& parameter : node.parameters) diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 4cd73eedc..c2a05aad3 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -57,7 +57,7 @@ namespace Nz::ShaderLang switch (nextToken.type) { case TokenType::OpenAttribute: - context.pendingAttributes = ParseAttributes(); + HandleAttributes(); break; case TokenType::FunctionDeclaration: @@ -98,6 +98,14 @@ namespace Nz::ShaderLang return token; } + const Token& Parser::ExpectNot(const Token& token, TokenType type) + { + if (token.type == type) + throw ExpectedToken{}; + + return token; + } + const Token& Parser::Expect(TokenType type) { const Token& token = Peek(); @@ -112,7 +120,7 @@ namespace Nz::ShaderLang return m_context->tokens[m_context->tokenIndex + advance]; } - std::vector Parser::ParseAttributes() + void Parser::HandleAttributes() { std::vector attributes; @@ -122,6 +130,8 @@ namespace Nz::ShaderLang for (;;) { const Token& t = Peek(); + ExpectNot(t, TokenType::EndOfStream); + if (t.type == TokenType::ClosingAttribute) { // Parse [[attribute1]] [[attribute2]] the same as [[attribute1, attribute2]] @@ -161,7 +171,16 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::ClosingAttribute); - return attributes; + const Token& nextToken = Peek(); + switch (nextToken.type) + { + case TokenType::FunctionDeclaration: + m_context->root->statements.push_back(ParseFunctionDeclaration(std::move(attributes))); + break; + + default: + throw UnexpectedToken{}; + } } std::vector Parser::ParseFunctionBody() @@ -169,7 +188,7 @@ namespace Nz::ShaderLang return ParseStatementList(); } - ShaderAst::StatementPtr Parser::ParseFunctionDeclaration() + ShaderAst::StatementPtr Parser::ParseFunctionDeclaration(std::vector attributes) { Expect(Advance(), TokenType::FunctionDeclaration); @@ -183,6 +202,8 @@ namespace Nz::ShaderLang for (;;) { const Token& t = Peek(); + ExpectNot(t, TokenType::EndOfStream); + if (t.type == TokenType::ClosingParenthesis) break; @@ -208,7 +229,7 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::ClosingCurlyBracket); - return ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType)); + return ShaderBuilder::DeclareFunction(std::move(attributes), std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType)); } ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter() @@ -262,6 +283,7 @@ namespace Nz::ShaderLang std::vector statements; while (Peek().type != TokenType::ClosingCurlyBracket) { + ExpectNot(Peek(), TokenType::EndOfStream); statements.push_back(ParseStatement()); } @@ -293,6 +315,7 @@ namespace Nz::ShaderLang for (;;) { const Token& currentOp = Peek(); + ExpectNot(currentOp, TokenType::EndOfStream); int tokenPrecedence = GetTokenPrecedence(currentOp.type); if (tokenPrecedence < exprPrecedence) diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index c877db5a0..eb52edeca 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include @@ -172,6 +171,7 @@ namespace Nz struct Func { + const ShaderAst::DeclareFunctionStatement* statement = nullptr; UInt32 typeId; UInt32 id; }; @@ -355,6 +355,7 @@ namespace Nz for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs) { auto& funcData = state.funcs.emplace_back(); + funcData.statement = &func; funcData.id = AllocateResultId(); funcData.typeId = GetFunctionTypeId(func); @@ -407,14 +408,21 @@ namespace Nz AppendHeader(); - /*if (entryPointIndex != std::numeric_limits::max()) + for (std::size_t i = 0; i < ShaderStageTypeCount; ++i) { - SpvExecutionModel execModel; - const auto& entryFuncData = shader.GetFunction(entryPointIndex); - const auto& entryFunc = state.funcs[entryPointIndex]; + const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i]; + if (!statement) + continue; - assert(m_context.shader); - switch (m_context.shader->GetStage()) + auto it = std::find_if(state.funcs.begin(), state.funcs.end(), [&](const auto& funcData) { return funcData.statement == statement; }); + assert(it != state.funcs.end()); + + const auto& entryFunc = *it; + + SpirvExecutionModel execModel; + + ShaderStageType stage = static_cast(i); + switch (stage) { case ShaderStageType::Fragment: execModel = SpirvExecutionModel::Fragment; @@ -427,14 +435,12 @@ namespace Nz default: throw std::runtime_error("not yet implemented"); } - - // OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos - + state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender) { appender(execModel); appender(entryFunc.id); - appender(entryFuncData.name); + appender(statement->name); for (const auto& [name, varData] : state.builtinIds) appender(varData.varId); @@ -446,9 +452,9 @@ namespace Nz appender(varData.varId); }); - if (m_context.shader->GetStage() == ShaderStageType::Fragment) + if (stage == ShaderStageType::Fragment) state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft); - }*/ + } std::vector ret; MergeSections(ret, state.header); @@ -472,7 +478,7 @@ namespace Nz void SpirvWriter::AppendHeader() { - m_currentState->header.AppendRaw(SpvMagicNumber); //< Spir-V magic number + m_currentState->header.AppendRaw(SpirvMagicNumber); //< Spir-V magic number UInt32 version = (m_environment.spvMajorVersion << 16) | m_environment.spvMinorVersion << 8; m_currentState->header.AppendRaw(version); //< Spir-V version number (1.0 for compatibility)