From 16e2f5f8197c40e30f660c7e73a9be3c795b6d1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Tue, 1 Jun 2021 12:32:24 +0200 Subject: [PATCH] Shader: Add support for depth_write and early_fragment_tests attributes (+ FragDepth builtin) --- include/Nazara/Shader/Ast/Enums.hpp | 184 +++++++++++------- include/Nazara/Shader/Ast/Nodes.hpp | 2 + include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 22 ++- include/Nazara/Shader/LangWriter.hpp | 6 +- include/Nazara/Shader/SpirvAstVisitor.hpp | 1 + src/Nazara/Shader/Ast/AstCloner.cpp | 2 + src/Nazara/Shader/Ast/AstSerializer.cpp | 2 + src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 181 +++++++++++++---- src/Nazara/Shader/GlslWriter.cpp | 14 ++ src/Nazara/Shader/LangWriter.cpp | 62 +++++- src/Nazara/Shader/ShaderLangParser.cpp | 72 ++++++- src/Nazara/Shader/SpirvWriter.cpp | 44 ++++- 12 files changed, 456 insertions(+), 136 deletions(-) diff --git a/include/Nazara/Shader/Ast/Enums.hpp b/include/Nazara/Shader/Ast/Enums.hpp index 4717c2ab9..f719afa59 100644 --- a/include/Nazara/Shader/Ast/Enums.hpp +++ b/include/Nazara/Shader/Ast/Enums.hpp @@ -8,97 +8,131 @@ #define NAZARA_SHADER_AST_ENUMS_HPP #include +#include -namespace Nz::ShaderAst +namespace Nz { - enum class AssignType + namespace ShaderAst { - Simple //< = + enum class AssignType + { + Simple //< = + }; + + enum class AttributeType + { + Binding, //< Binding (external var only) - has argument index + Builtin, //< Builtin (struct member only) - has argument type + DepthWrite, //< Depth write mode (function only) - has argument type + EarlyFragmentTests, //< Entry point (function only) - has argument on/off + Entry, //< Entry point (function only) - has argument type + Layout, //< Struct layout (struct only) - has argument style + Location, //< Location (struct member only) - has argument index + Option, //< Conditional compilation option - has argument expr + }; + + enum class BinaryType + { + Add, //< + + Subtract, //< - + Multiply, //< * + Divide, //< / + + CompEq, //< == + CompGe, //< >= + CompGt, //< > + CompLe, //< <= + CompLt, //< < + CompNe //< <= + }; + + enum class BuiltinEntry + { + FragCoord = 1, // gl_FragCoord + FragDepth = 2, // gl_FragDepth + VertexPosition = 0, // gl_Position + }; + + enum class DepthWriteMode + { + Greater, + Less, + Replace, + Unchanged, + }; + + enum class ExpressionCategory + { + LValue, + RValue + }; + + enum class FunctionFlag + { + DoesDiscard, + DoesWriteFragDepth, + + Max = DoesWriteFragDepth + }; + } + + template<> + struct EnumAsFlags + { + static constexpr ShaderAst::FunctionFlag max = ShaderAst::FunctionFlag::Max; }; - enum class AttributeType + namespace ShaderAst { - Binding, //< Binding (external var only) - has argument index - Builtin, //< Builtin (struct member only) - has argument type - Entry, //< Entry point (function only) - has argument type - Layout, //< Struct layout (struct only) - has argument style - Location, //< Location (struct member only) - has argument index - Option, //< Conditional compilation option - has argument expr - }; + using FunctionFlags = Flags; - enum class BinaryType - { - Add, //< + - Subtract, //< - - Multiply, //< * - Divide, //< / + enum class IntrinsicType + { + CrossProduct = 0, + DotProduct = 1, + Length = 3, + Max = 4, + Min = 5, + SampleTexture = 2, + }; - CompEq, //< == - CompGe, //< >= - CompGt, //< > - CompLe, //< <= - CompLt, //< < - CompNe //< <= - }; + enum class MemoryLayout + { + Std140 + }; - enum class BuiltinEntry - { - FragCoord = 1, // gl_FragCoord - VertexPosition = 0, // gl_Position - }; - - enum class ExpressionCategory - { - LValue, - RValue - }; - - enum class IntrinsicType - { - CrossProduct = 0, - DotProduct = 1, - Length = 3, - Max = 4, - Min = 5, - SampleTexture = 2, - }; - - enum class MemoryLayout - { - Std140 - }; - - enum class NodeType - { - None = -1, + enum class NodeType + { + None = -1, #define NAZARA_SHADERAST_NODE(Node) Node, #define NAZARA_SHADERAST_STATEMENT_LAST(Node) Node, Max = Node #include - }; + }; - enum class PrimitiveType - { - Boolean, //< bool - Float32, //< f32 - Int32, //< i32 - UInt32, //< ui32 - }; + enum class PrimitiveType + { + Boolean, //< bool + Float32, //< f32 + Int32, //< i32 + UInt32, //< ui32 + }; - enum class SwizzleComponent - { - First, - Second, - Third, - Fourth - }; + enum class SwizzleComponent + { + First, + Second, + Third, + Fourth + }; - enum class UnaryType - { - LogicalNot, //< !v - Minus, //< -v - Plus, //< +v - }; + enum class UnaryType + { + LogicalNot, //< !v + Minus, //< -v + Plus, //< +v + }; + } } #endif // NAZARA_SHADER_ENUMS_HPP diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index f44c74693..28ba5cd0e 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -272,6 +272,8 @@ namespace Nz::ShaderAst ExpressionType type; }; + std::optional depthWrite; + std::optional earlyFragmentTests; std::optional entryStage; std::optional funcIndex; std::optional varIndex; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 86bf41b2c..54f67ec7c 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -8,6 +8,7 @@ #define NAZARA_SHADERAST_TRANSFORMVISITOR_HPP #include +#include #include #include #include @@ -39,6 +40,7 @@ namespace Nz::ShaderAst }; private: + struct FunctionData; struct Identifier; const ExpressionType& CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices); @@ -65,6 +67,7 @@ namespace Nz::ShaderAst StatementPtr Clone(DeclareOptionStatement& node) override; StatementPtr Clone(DeclareStructStatement& node) override; StatementPtr Clone(DeclareVariableStatement& node) override; + StatementPtr Clone(DiscardStatement& node) override; StatementPtr Clone(ExpressionStatement& node) override; StatementPtr Clone(MultiStatement& node) override; @@ -78,12 +81,18 @@ namespace Nz::ShaderAst void PushScope(); void PopScope(); - std::size_t RegisterFunction(DeclareFunctionStatement* funcDecl); + std::size_t DeclareFunction(DeclareFunctionStatement* funcDecl); + + void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen); + + FunctionData& RegisterFunction(std::size_t functionIndex); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); std::size_t RegisterOption(std::string name, ExpressionType type); std::size_t RegisterStruct(std::string name, StructDescription description); std::size_t RegisterVariable(std::string name, ExpressionType type); + void ResolveFunctions(); + std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const IdentifierType& identifierType); std::size_t ResolveStruct(const StructType& structType); @@ -95,6 +104,14 @@ namespace Nz::ShaderAst void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration); void Validate(IntrinsicExpression& node); + struct FunctionData + { + Bitset<> calledByFunctions; + DeclareFunctionStatement* node; + FunctionFlags flags; + bool defined = false; + }; + struct Identifier { enum class Type @@ -112,9 +129,8 @@ namespace Nz::ShaderAst Type type; }; - std::unordered_map> m_functionDeclarations; std::vector m_identifiersInScope; - std::vector m_functions; + std::vector m_functions; std::vector m_intrinsics; std::vector m_options; std::vector m_structs; diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index dbc0644de..fac98f803 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -39,6 +39,8 @@ namespace Nz private: struct BindingAttribute; struct BuiltinAttribute; + struct DepthWriteAttribute; + struct EarlyFragmentTestsAttribute; struct EntryAttribute; struct LayoutAttribute; struct LocationAttribute; @@ -55,8 +57,10 @@ namespace Nz template void Append(const T& param); template void Append(const T1& firstParam, const T2& secondParam, Args&&... params); template void AppendAttributes(bool appendLine, Args&&... params); - void AppendAttribute(BindingAttribute builtin); + void AppendAttribute(BindingAttribute binding); void AppendAttribute(BuiltinAttribute builtin); + void AppendAttribute(DepthWriteAttribute depthWrite); + void AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests); void AppendAttribute(EntryAttribute entry); void AppendAttribute(LayoutAttribute layout); void AppendAttribute(LocationAttribute location); diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index 2a5531236..6b4c123bb 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -95,6 +95,7 @@ namespace Nz std::optional outputStructTypeId; std::vector inputs; std::vector outputs; + std::vector executionModes; }; struct FuncData diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index ad37f7e0b..aa016a304 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -74,6 +74,8 @@ namespace Nz::ShaderAst StatementPtr AstCloner::Clone(DeclareFunctionStatement& node) { auto clone = std::make_unique(); + clone->depthWrite = node.depthWrite; + clone->earlyFragmentTests = node.earlyFragmentTests; clone->entryStage = node.entryStage; clone->funcIndex = node.funcIndex; clone->name = node.name; diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 654e99ea8..bcf189ee9 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -228,6 +228,8 @@ namespace Nz::ShaderAst { Value(node.name); Type(node.returnType); + OptEnum(node.depthWrite); + OptVal(node.earlyFragmentTests); OptEnum(node.entryStage); OptVal(node.funcIndex); Value(node.optionName); diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index d5e3422ce..d9cbef23b 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -30,10 +30,19 @@ namespace Nz::ShaderAst struct SanitizeVisitor::Context { + struct FunctionData + { + std::optional stageType; + Bitset<> calledFunctions; + DeclareFunctionStatement* statement; + FunctionFlags flags; + }; + Options options; std::array entryFunctions = {}; std::unordered_set declaredExternalVar; std::unordered_set usedBindingIndexes; + FunctionData* currentFunction = nullptr; }; StatementPtr SanitizeVisitor::Sanitize(const StatementPtr& nodePtr, const Options& options, std::string* error) @@ -57,18 +66,15 @@ namespace Nz::ShaderAst // Collect function name and their types if (nodePtr->GetType() == NodeType::MultiStatement) { - std::size_t functionIndex = 0; - const MultiStatement& multiStatement = static_cast(*nodePtr); - for (const auto& statementPtr : multiStatement.statements) + for (auto& statementPtr : multiStatement.statements) { if (statementPtr->GetType() == NodeType::DeclareFunctionStatement) - { - const DeclareFunctionStatement& funcDeclaration = static_cast(*statementPtr); - m_functionDeclarations.emplace(funcDeclaration.name, std::make_pair(&funcDeclaration, functionIndex++)); - } + DeclareFunction(static_cast(statementPtr.get())); } } + else if (nodePtr->GetType() == NodeType::DeclareFunctionStatement) + DeclareFunction(static_cast(nodePtr.get())); try { @@ -81,6 +87,8 @@ namespace Nz::ShaderAst *error = err.errMsg; } + + ResolveFunctions(); } PopScope(); @@ -380,7 +388,8 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node) { - constexpr std::size_t NoFunction = std::numeric_limits::max(); + if (!m_context->currentFunction) + throw AstError{ "function calls must happen inside a function" }; auto clone = std::make_unique(); @@ -388,7 +397,7 @@ namespace Nz::ShaderAst for (std::size_t i = 0; i < node.parameters.size(); ++i) clone->parameters.push_back(CloneExpression(node.parameters[i])); - const DeclareFunctionStatement* referenceFunctionDeclaration; + std::size_t targetFuncIndex; if (std::holds_alternative(node.targetFunction)) { const std::string& functionName = std::get(node.targetFunction); @@ -417,28 +426,27 @@ namespace Nz::ShaderAst throw AstError{ "function expected" }; clone->targetFunction = identifier->index; - referenceFunctionDeclaration = m_functions[identifier->index]; + targetFuncIndex = identifier->index; } } else { // Identifier not found, maybe the function is declared later - auto it = m_functionDeclarations.find(functionName); - if (it == m_functionDeclarations.end()) + auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node->name == functionName; }); + if (it == m_functions.end()) throw AstError{ "function " + functionName + " does not exist" }; - clone->targetFunction = it->second.second; + targetFuncIndex = std::distance(m_functions.begin(), it); - referenceFunctionDeclaration = it->second.first; + clone->targetFunction = targetFuncIndex; } } else - { - std::size_t funcIndex = std::get(node.targetFunction); - referenceFunctionDeclaration = m_functions[funcIndex]; - } + targetFuncIndex = std::get(node.targetFunction); - Validate(*clone, referenceFunctionDeclaration); + m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex); + + Validate(*clone, m_functions[targetFuncIndex].node); return clone; } @@ -447,6 +455,19 @@ namespace Nz::ShaderAst { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + clone->cachedExpressionType = clone->targetType; + clone->targetType = ResolveType(clone->targetType); + + //FIXME: Make proper rules + if (IsMatrixType(clone->targetType) && clone->expressions.front()) + { + const ExpressionType& exprType = GetExpressionType(*clone->expressions.front()); + if (IsMatrixType(exprType) && !clone->expressions[1]) + { + return clone; + } + } + auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t { if (IsVectorType(exprType)) @@ -476,9 +497,6 @@ namespace Nz::ShaderAst if (componentCount != requiredComponents) throw AstError{ "component count doesn't match required component count" }; - clone->targetType = ResolveType(clone->targetType); - clone->cachedExpressionType = clone->targetType; - return clone; } @@ -732,9 +750,20 @@ namespace Nz::ShaderAst if (node.parameters.size() > 1) throw AstError{ "entry functions can either take one struct parameter or no parameter" }; + + if (stageType != ShaderStageType::Fragment) + { + if (node.depthWrite.has_value()) + throw AstError{ "only fragment entry-points can have the depth_write attribute" }; + + if (node.earlyFragmentTests.has_value()) + throw AstError{ "only functions with entry(frag) attribute can have the early_fragments_tests attribute" }; + } } auto clone = std::make_unique(); + clone->depthWrite = node.depthWrite; + clone->earlyFragmentTests = node.earlyFragmentTests; clone->entryStage = node.entryStage; clone->name = node.name; clone->optionName = node.optionName; @@ -743,6 +772,11 @@ namespace Nz::ShaderAst SanitizeIdentifier(clone->name); + Context::FunctionData tempFuncData; + tempFuncData.stageType = node.entryStage; + + m_context->currentFunction = &tempFuncData; + PushScope(); { for (auto& parameter : clone->parameters) @@ -761,6 +795,14 @@ namespace Nz::ShaderAst } PopScope(); + m_context->currentFunction = nullptr; + + if (clone->earlyFragmentTests.has_value() && *clone->earlyFragmentTests) + { + //TODO: warning and disable early fragment tests + throw AstError{ "discard is not compatible with early fragment tests" }; + } + if (!clone->optionName.empty()) { const Identifier* identifier = FindIdentifier(node.optionName); @@ -775,7 +817,23 @@ namespace Nz::ShaderAst return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone)); } - clone->funcIndex = RegisterFunction(clone.get()); + auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node == &node; }); + assert(it != m_functions.end()); + assert(!it->defined); + + std::size_t funcIndex = std::distance(m_functions.begin(), it); + + clone->funcIndex = funcIndex; + + auto& funcData = RegisterFunction(funcIndex); + funcData.flags = tempFuncData.flags; + + for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i)) + { + assert(i < m_functions.size()); + auto& targetFunc = m_functions[i]; + targetFunc.calledByFunctions.UnboundedSet(funcIndex); + } return clone; } @@ -840,6 +898,16 @@ namespace Nz::ShaderAst return clone; } + StatementPtr SanitizeVisitor::Clone(DiscardStatement& node) + { + if (!m_context->currentFunction) + throw AstError{ "discard can only be used inside a function" }; + + m_context->currentFunction->flags |= FunctionFlag::DoesDiscard; + + return AstCloner::Clone(node); + } + StatementPtr SanitizeVisitor::Clone(ExpressionStatement& node) { MandatoryExpr(node.expression); @@ -889,34 +957,57 @@ namespace Nz::ShaderAst m_scopeSizes.pop_back(); } - std::size_t SanitizeVisitor::RegisterFunction(DeclareFunctionStatement* funcDecl) + std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement* funcDecl) { - if (auto* identifier = FindIdentifier(funcDecl->name)) + std::size_t functionIndex = m_functions.size(); + auto& funcData = m_functions.emplace_back(); + funcData.node = funcDecl; + + return functionIndex; + } + + void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen) + { + assert(funcIndex < m_functions.size()); + auto& funcData = m_functions[funcIndex]; + assert(funcData.defined); + + funcData.flags |= flags; + + for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i)) + PropagateFunctionFlags(i, funcData.flags, seen); + } + + auto SanitizeVisitor::RegisterFunction(std::size_t functionIndex) -> FunctionData& + { + assert(m_functions.size() >= functionIndex); + auto& funcData = m_functions[functionIndex]; + assert(!funcData.defined); + funcData.defined = true; + + if (auto* identifier = FindIdentifier(funcData.node->name)) { bool duplicate = true; // Functions cannot be declared twice, except for entry ones if their stages are different - if (funcDecl->entryStage && identifier->type == Identifier::Type::Function) + if (funcData.node->entryStage && identifier->type == Identifier::Type::Function) { auto& otherFunction = m_functions[identifier->index]; - if (funcDecl->entryStage != otherFunction->entryStage) + if (funcData.node->entryStage != otherFunction.node->entryStage) duplicate = false; } if (duplicate) - throw AstError{ funcDecl->name + " is already used" }; + throw AstError{ funcData.node->name + " is already used" }; } - std::size_t functionIndex = m_functions.size(); - m_functions.push_back(funcDecl); - m_identifiersInScope.push_back({ - funcDecl->name, + funcData.node->name, functionIndex, Identifier::Type::Function - }); + }); - return functionIndex; + return funcData; } std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type) @@ -965,7 +1056,7 @@ namespace Nz::ShaderAst std::move(name), structIndex, Identifier::Type::Struct - }); + }); return structIndex; } @@ -988,6 +1079,26 @@ namespace Nz::ShaderAst return varIndex; } + void SanitizeVisitor::ResolveFunctions() + { + // Once every function is known, we can propagate flags + + Bitset<> seen; + for (std::size_t funcIndex = 0; funcIndex < m_functions.size(); ++funcIndex) + { + auto& funcData = m_functions[funcIndex]; + + PropagateFunctionFlags(funcIndex, funcData.flags, seen); + seen.Clear(); + } + + for (const FunctionData& funcData : m_functions) + { + if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage && *funcData.node->entryStage != ShaderStageType::Fragment) + throw AstError{ "discard can only be used in the fragment stage" }; + } + } + std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType) { return std::visit([&](auto&& arg) -> std::size_t diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 9d793a628..d4829b8f2 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -110,6 +110,7 @@ namespace Nz std::unordered_map s_builtinMapping = { { ShaderAst::BuiltinEntry::FragCoord, { "gl_FragCoord", ShaderStageType::Fragment } }, + { ShaderAst::BuiltinEntry::FragDepth, { "gl_FragDepth", ShaderStageType::Fragment } }, { ShaderAst::BuiltinEntry::VertexPosition, { "gl_Position", ShaderStageType::Vertex } } }; } @@ -206,6 +207,10 @@ namespace Nz Append("gl_FragCoord"); break; + case ShaderAst::BuiltinEntry::FragDepth: + Append("gl_FragDepth"); + break; + case ShaderAst::BuiltinEntry::VertexPosition: Append("gl_Position"); break; @@ -500,6 +505,15 @@ namespace Nz void GlslWriter::HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node) { + if (node.entryStage == ShaderStageType::Fragment && node.earlyFragmentTests && *node.earlyFragmentTests) + { + if ((m_environment.glES && m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 1) || (!m_environment.glES && m_environment.glMajorVersion >= 4 && m_environment.glMinorVersion >= 2) || m_environment.extCallback("GL_ARB_shader_image_load_store")) + { + AppendLine("layout(early_fragment_tests) in;"); + AppendLine(); + } + } + HandleInOut(); AppendLine("void main()"); EnterScope(); diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 0d0b8c237..6b4182787 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -41,6 +41,20 @@ namespace Nz inline bool HasValue() const { return builtin.has_value(); } }; + struct LangWriter::DepthWriteAttribute + { + std::optional writeMode; + + inline bool HasValue() const { return writeMode.has_value(); } + }; + + struct LangWriter::EarlyFragmentTestsAttribute + { + std::optional earlyFragmentTests; + + inline bool HasValue() const { return earlyFragmentTests.has_value(); } + }; + struct LangWriter::EntryAttribute { std::optional stageType; @@ -216,12 +230,12 @@ namespace Nz Append(" "); } - void LangWriter::AppendAttribute(BindingAttribute builtin) + void LangWriter::AppendAttribute(BindingAttribute binding) { - if (!builtin.HasValue()) + if (!binding.HasValue()) return; - Append("binding(", *builtin.bindingIndex, ")"); + Append("binding(", *binding.bindingIndex, ")"); } void LangWriter::AppendAttribute(BuiltinAttribute builtin) @@ -235,11 +249,51 @@ namespace Nz Append("builtin(fragcoord)"); break; + case ShaderAst::BuiltinEntry::FragDepth: + Append("builtin(fragdepth)"); + break; + case ShaderAst::BuiltinEntry::VertexPosition: Append("builtin(position)"); break; } } + + void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite) + { + if (!depthWrite.HasValue()) + return; + + switch (*depthWrite.writeMode) + { + case ShaderAst::DepthWriteMode::Greater: + Append("depth_write(greater)"); + break; + + case ShaderAst::DepthWriteMode::Less: + Append("depth_write(less)"); + break; + + case ShaderAst::DepthWriteMode::Replace: + Append("depth_write(replace)"); + break; + + case ShaderAst::DepthWriteMode::Unchanged: + Append("depth_write(unchanged)"); + break; + } + } + + void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests) + { + if (!earlyFragmentTests.HasValue()) + return; + + if (*earlyFragmentTests.earlyFragmentTests) + Append("early_fragment_tests(on)"); + else + Append("early_fragment_tests(off)"); + } void LangWriter::AppendAttribute(EntryAttribute entry) { @@ -553,7 +607,7 @@ namespace Nz std::optional varIndexOpt = node.varIndex; - AppendAttributes(true, EntryAttribute{ node.entryStage }); + AppendAttributes(true, EntryAttribute{ node.entryStage }, EarlyFragmentTestsAttribute{ node.earlyFragmentTests }, DepthWriteAttribute{ node.depthWrite }); Append("fn ", node.name, "("); for (std::size_t i = 0; i < node.parameters.size(); ++i) { diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index adf60a972..470cbed8b 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -11,7 +11,14 @@ namespace Nz::ShaderLang { namespace - { + { + std::unordered_map s_depthWriteModes = { + { "greater", ShaderAst::DepthWriteMode::Greater }, + { "less", ShaderAst::DepthWriteMode::Less }, + { "replace", ShaderAst::DepthWriteMode::Replace }, + { "unchanged", ShaderAst::DepthWriteMode::Unchanged }, + }; + std::unordered_map s_identifierToBasicType = { { "bool", ShaderAst::PrimitiveType::Boolean }, { "i32", ShaderAst::PrimitiveType::Int32 }, @@ -20,12 +27,14 @@ namespace Nz::ShaderLang }; std::unordered_map s_identifierToAttributeType = { - { "binding", ShaderAst::AttributeType::Binding }, - { "builtin", ShaderAst::AttributeType::Builtin }, - { "entry", ShaderAst::AttributeType::Entry }, - { "layout", ShaderAst::AttributeType::Layout }, - { "location", ShaderAst::AttributeType::Location }, - { "opt", ShaderAst::AttributeType::Option }, + { "binding", ShaderAst::AttributeType::Binding }, + { "builtin", ShaderAst::AttributeType::Builtin }, + { "depth_write", ShaderAst::AttributeType::DepthWrite }, + { "early_fragment_tests", ShaderAst::AttributeType::EarlyFragmentTests }, + { "entry", ShaderAst::AttributeType::Entry }, + { "layout", ShaderAst::AttributeType::Layout }, + { "location", ShaderAst::AttributeType::Location }, + { "opt", ShaderAst::AttributeType::Option }, }; std::unordered_map s_entryPoints = { @@ -35,6 +44,7 @@ namespace Nz::ShaderLang std::unordered_map s_builtinMapping = { { "fragcoord", ShaderAst::BuiltinEntry::FragCoord }, + { "fragdepth", ShaderAst::BuiltinEntry::FragDepth }, { "position", ShaderAst::BuiltinEntry::VertexPosition } }; @@ -483,11 +493,55 @@ namespace Nz::ShaderLang for (const auto& [attributeType, arg] : attributes) { switch (attributeType) - { + { + case ShaderAst::AttributeType::DepthWrite: + { + if (func->depthWrite) + throw AttributeError{ "attribute depth_write can only be present once" }; + + if (!std::holds_alternative(arg)) + throw AttributeError{ "attribute entry requires a string parameter" }; + + const std::string& argStr = std::get(arg); + + auto it = s_depthWriteModes.find(argStr); + if (it == s_depthWriteModes.end()) + throw AttributeError{ ("invalid parameter " + argStr + " for depth_write attribute").c_str() }; + + func->depthWrite = it->second; + break; + } + + case ShaderAst::AttributeType::EarlyFragmentTests: + { + if (func->earlyFragmentTests) + throw AttributeError{ "attribute early_fragment_tests can only be present once" }; + + if (std::holds_alternative(arg)) + { + const std::string& argStr = std::get(arg); + if (argStr == "true" || argStr == "on") + func->earlyFragmentTests = true; + else if (argStr == "false" || argStr == "off") + func->earlyFragmentTests = false; + else + throw AttributeError{ "expected boolean value (got " + argStr + ")" }; + } + else if (std::holds_alternative(arg)) + { + // No parameter, default to true + func->earlyFragmentTests = true; + } + else + throw AttributeError{ "unexpected value for early_fragment_tests" }; + + break; + } + case ShaderAst::AttributeType::Entry: { if (func->entryStage) - throw AttributeError{ "attribute entry must be present once" }; + throw AttributeError{ "attribute entry can only be present once" }; if (!std::holds_alternative(arg)) throw AttributeError{ "attribute entry requires a string parameter" }; diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 4406719d5..24784c551 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -37,7 +37,8 @@ namespace Nz std::unordered_map s_builtinMapping = { { ShaderAst::BuiltinEntry::FragCoord, { "FragmentCoordinates", ShaderStageType::Fragment, SpirvBuiltIn::FragCoord } }, - { ShaderAst::BuiltinEntry::VertexPosition, { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } } + { ShaderAst::BuiltinEntry::FragDepth, { "FragmentDepth", ShaderStageType::Fragment, SpirvBuiltIn::FragDepth } }, + { ShaderAst::BuiltinEntry::VertexPosition, { "VertexPosition", ShaderStageType::Vertex, SpirvBuiltIn::Position } } }; class PreVisitor : public ShaderAst::AstRecursiveVisitor @@ -181,6 +182,28 @@ namespace Nz { using EntryPoint = SpirvAstVisitor::EntryPoint; + std::vector executionModes; + + if (*entryPointType == ShaderStageType::Fragment) + { + executionModes.push_back(SpirvExecutionMode::OriginUpperLeft); + if (node.earlyFragmentTests && *node.earlyFragmentTests) + executionModes.push_back(SpirvExecutionMode::EarlyFragmentTests); + + if (node.depthWrite) + { + executionModes.push_back(SpirvExecutionMode::DepthReplacing); + + switch (*node.depthWrite) + { + case ShaderAst::DepthWriteMode::Replace: break; + case ShaderAst::DepthWriteMode::Greater: executionModes.push_back(SpirvExecutionMode::DepthGreater); break; + case ShaderAst::DepthWriteMode::Less: executionModes.push_back(SpirvExecutionMode::DepthLess); break; + case ShaderAst::DepthWriteMode::Unchanged: executionModes.push_back(SpirvExecutionMode::DepthUnchanged); break; + } + } + } + funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{})); funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, {})); @@ -248,7 +271,8 @@ namespace Nz inputStruct, outputStructId, std::move(inputs), - std::move(outputs) + std::move(outputs), + std::move(executionModes) }; } @@ -522,7 +546,6 @@ namespace Nz m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450); - std::optional fragmentFuncId; for (auto& func : m_currentState->funcs) { m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name); @@ -559,15 +582,18 @@ namespace Nz for (const auto& output : entryPointData.outputs) appender(output.varId); }); - - if (entryPointData.stageType == ShaderStageType::Fragment) - fragmentFuncId = func.funcId; } } - if (fragmentFuncId) - m_currentState->header.Append(SpirvOp::OpExecutionMode, *fragmentFuncId, SpirvExecutionMode::OriginUpperLeft); - + // Write execution modes + for (auto& func : m_currentState->funcs) + { + if (func.entryPointData) + { + for (SpirvExecutionMode executionMode : func.entryPointData->executionModes) + m_currentState->header.Append(SpirvOp::OpExecutionMode, func.funcId, executionMode); + } + } } SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)