diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index 1d4796379..a0990be8b 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -498,6 +498,9 @@ namespace Nz::ShaderAst if (!Compare(lhs.entryStage, rhs.entryStage)) return false; + if (!Compare(lhs.isExported, rhs.isExported)) + return false; + if (!Compare(lhs.name, rhs.name)) return false; diff --git a/include/Nazara/Shader/Ast/AstExportVisitor.hpp b/include/Nazara/Shader/Ast/AstExportVisitor.hpp index 49db1d5e8..baeab0701 100644 --- a/include/Nazara/Shader/Ast/AstExportVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstExportVisitor.hpp @@ -31,10 +31,12 @@ namespace Nz::ShaderAst struct Callbacks { + std::function onExportedFunc; std::function onExportedStruct; }; private: + void Visit(DeclareFunctionStatement& node) override; void Visit(DeclareStructStatement& node) override; const Callbacks* m_callbacks; diff --git a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp index 06962beac..34ef0139f 100644 --- a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp +++ b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.hpp @@ -27,6 +27,7 @@ namespace Nz::ShaderAst inline const UsageSet& GetUsage() const; + inline void MarkFunctionAsUsed(std::size_t funcIndex); inline void MarkStructAsUsed(std::size_t structIndex); inline void Process(Statement& statement); diff --git a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl index f1708fd41..3e2a7a8c0 100644 --- a/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl +++ b/include/Nazara/Shader/Ast/DependencyCheckerVisitor.inl @@ -12,6 +12,11 @@ namespace Nz::ShaderAst return m_resolvedUsage; } + inline void DependencyCheckerVisitor::MarkFunctionAsUsed(std::size_t funcIndex) + { + m_globalUsage.usedFunctions.UnboundedSet(funcIndex); + } + inline void DependencyCheckerVisitor::MarkStructAsUsed(std::size_t structIndex) { m_globalUsage.usedStructs.UnboundedSet(structIndex); diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 61bc382b2..3ffdb5c1a 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -335,6 +335,7 @@ namespace Nz::ShaderAst ExpressionValue entryStage; ExpressionValue returnType; ExpressionValue earlyFragmentTests; + ExpressionValue isExported; }; struct NAZARA_SHADER_API DeclareOptionStatement : Statement diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index facdefeb4..711c74d81 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -74,7 +74,7 @@ namespace Nz template void Append(const T1& firstParam, const T2& secondParam, Args&&... params); void AppendComment(const std::string& section); void AppendCommentSection(const std::string& section); - void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward = false); + void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, const std::string& nameOverride, bool forward = false); void AppendHeader(); void AppendLine(const std::string& txt = {}); template void AppendLine(Args&&... params); diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index 5dc57be47..c522165c8 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -124,6 +124,7 @@ namespace Nz::ShaderAst clone->earlyFragmentTests = Clone(node.earlyFragmentTests); clone->entryStage = Clone(node.entryStage); clone->funcIndex = node.funcIndex; + clone->isExported = Clone(node.isExported); clone->name = node.name; clone->returnType = Clone(node.returnType); diff --git a/src/Nazara/Shader/Ast/AstExportVisitor.cpp b/src/Nazara/Shader/Ast/AstExportVisitor.cpp index 3d79f5f35..6a0643860 100644 --- a/src/Nazara/Shader/Ast/AstExportVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstExportVisitor.cpp @@ -14,6 +14,15 @@ namespace Nz::ShaderAst statement.Visit(*this); } + void AstExportVisitor::Visit(DeclareFunctionStatement& node) + { + if (!node.isExported.HasValue() || !node.isExported.GetResultingValue()) + return; + + if (m_callbacks->onExportedFunc) + m_callbacks->onExportedFunc(node); + } + void AstExportVisitor::Visit(DeclareStructStatement& node) { if (!node.isExported.HasValue() || !node.isExported.GetResultingValue()) diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 94611d465..1bd7b4962 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -260,6 +260,7 @@ namespace Nz::ShaderAst ExprValue(node.depthWrite); ExprValue(node.earlyFragmentTests); ExprValue(node.entryStage); + ExprValue(node.isExported); OptVal(node.funcIndex); Container(node.parameters); diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 79b5aed88..c1b7aaf64 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -496,16 +496,28 @@ namespace Nz::ShaderAst { ExpressionPtr targetExpr = CloneExpression(MandatoryExpr(node.targetFunction)); const ExpressionType& targetExprType = GetExpressionType(*targetExpr); + const ExpressionType& resolvedType = ResolveAlias(targetExprType); - if (IsFunctionType(targetExprType)) + if (IsFunctionType(resolvedType)) { if (!m_context->currentFunction) throw AstError{ "function calls must happen inside a function" }; - if (targetExpr->GetType() != NodeType::FunctionExpression) - throw AstError{ "expected function expression" }; + std::size_t targetFuncIndex; + if (targetExpr->GetType() == NodeType::FunctionExpression) + targetFuncIndex = static_cast(*targetExpr).funcId; + else if (targetExpr->GetType() == NodeType::AliasValueExpression) + { + const auto& alias = static_cast(*targetExpr); - std::size_t targetFuncIndex = static_cast(*targetExpr).funcId; + const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId)); + if (targetIdentifier->category != IdentifierCategory::Function) + throw AstError{ "expected function expression" }; + + targetFuncIndex = targetIdentifier->index; + } + else + throw AstError{ "expected function expression" }; auto clone = std::make_unique(); clone->targetFunction = std::move(targetExpr); @@ -520,7 +532,7 @@ namespace Nz::ShaderAst return clone; } - else if (IsIntrinsicFunctionType(targetExprType)) + else if (IsIntrinsicFunctionType(resolvedType)) { if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression) throw AstError{ "expected intrinsic function expression" }; @@ -538,9 +550,9 @@ namespace Nz::ShaderAst return intrinsic; } - else if (IsMethodType(targetExprType)) + else if (IsMethodType(resolvedType)) { - const MethodType& methodType = std::get(targetExprType); + const MethodType& methodType = std::get(resolvedType); std::vector parameters; parameters.reserve(node.parameters.size() + 1); @@ -987,6 +999,9 @@ namespace Nz::ShaderAst if (node.entryStage.HasValue()) clone->entryStage = ComputeExprValue(node.entryStage); + if (node.isExported.HasValue()) + clone->isExported = ComputeExprValue(node.isExported); + if (clone->entryStage.HasValue()) { ShaderStageType stageType = clone->entryStage.GetResultingValue(); @@ -1468,7 +1483,7 @@ namespace Nz::ShaderAst hasher.Append(moduleUuidBytes.data(), moduleUuidBytes.size()); hasher.End(); - std::string identifier = "__" + hasher.End().ToHex().substr(0, 8); + std::string identifier = "_" + hasher.End().ToHex().substr(0, 8); // Load new module auto moduleEnvironment = std::make_shared(); @@ -1519,6 +1534,19 @@ namespace Nz::ShaderAst std::vector aliasStatements; AstExportVisitor::Callbacks callbacks; + callbacks.onExportedFunc = [&](DeclareFunctionStatement& node) + { + assert(node.funcIndex); + + moduleData.dependenciesVisitor->MarkFunctionAsUsed(*node.funcIndex); + + if (!exportedSet.usedFunctions.UnboundedTest(*node.funcIndex)) + { + exportedSet.usedFunctions.UnboundedSet(*node.funcIndex); + aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.name, ShaderBuilder::Function(*node.funcIndex))); + } + }; + callbacks.onExportedStruct = [&](DeclareStructStatement& node) { assert(node.structIndex); @@ -2475,13 +2503,18 @@ namespace Nz::ShaderAst std::size_t structIndex = ResolveStruct(exprType); node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex); } + else if (IsFunctionType(exprType)) + { + std::size_t funcIndex = std::get(exprType).funcIndex; + node.aliasIndex = RegisterAlias(node.name, { funcIndex, IdentifierCategory::Function }, node.aliasIndex); + } else if (IsAliasType(exprType)) { const AliasType& alias = std::get(exprType); node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex); } else - throw AstError{ "for now, only structs can be aliased" }; + throw AstError{ "for now, only aliases, functions and structs can be aliased" }; } void SanitizeVisitor::Validate(WhileStatement& node) @@ -2659,10 +2692,22 @@ namespace Nz::ShaderAst void SanitizeVisitor::Validate(CallFunctionExpression& node) { - const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction); - assert(std::holds_alternative(targetFuncType)); + std::size_t targetFuncIndex; + if (node.targetFunction->GetType() == NodeType::FunctionExpression) + targetFuncIndex = static_cast(*node.targetFunction).funcId; + else if (node.targetFunction->GetType() == NodeType::AliasValueExpression) + { + const auto& alias = static_cast(*node.targetFunction); + + const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId)); + if (targetIdentifier->category != IdentifierCategory::Function) + throw AstError{ "expected function expression" }; + + targetFuncIndex = targetIdentifier->index; + } + else + throw AstError{ "expected function expression" }; - std::size_t targetFuncIndex = std::get(targetFuncType).funcIndex; auto& funcData = m_context->functions.Retrieve(targetFuncIndex); const DeclareFunctionStatement* referenceDeclaration = funcData.node; diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index b274912af..0b180fec5 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -85,7 +85,7 @@ namespace Nz assert(node.funcIndex); assert(functions.find(node.funcIndex.value()) == functions.end()); FunctionData& funcData = functions[node.funcIndex.value()]; - funcData.name = node.name; + funcData.name = node.name + moduleSuffix; funcData.node = &node; currentFunction = &funcData; @@ -105,6 +105,7 @@ namespace Nz FunctionData* currentFunction = nullptr; std::optional selectedStage; + std::string moduleSuffix; std::unordered_map functions; ShaderAst::DeclareFunctionStatement* entryPoint = nullptr; }; @@ -201,8 +202,12 @@ namespace Nz state.previsitor.selectedStage = shaderStage; for (const auto& importedModule : targetModule->importedModules) + { + state.previsitor.moduleSuffix = importedModule.identifier; importedModule.module->rootNode->Visit(state.previsitor); + } + state.previsitor.moduleSuffix = {}; targetModule->rootNode->Visit(state.previsitor); if (!state.previsitor.entryPoint) @@ -216,7 +221,7 @@ namespace Nz for (const auto& importedModule : targetModule->importedModules) { - AppendComment("Module " + importedModule.module->metadata->moduleId.ToString()); + AppendComment("Module " + importedModule.module->metadata->moduleName); AppendLine(); m_currentState->moduleSuffix = importedModule.identifier; @@ -311,8 +316,7 @@ namespace Nz void GlslWriter::Append(const ShaderAst::FunctionType& functionType) { - const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionType.funcIndex).name; - Append(targetName); + throw std::runtime_error("unexpected FunctionType"); } void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/) @@ -475,9 +479,9 @@ namespace Nz AppendLine(); } - void GlslWriter::AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward) + void GlslWriter::AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, const std::string& nameOverride, bool forward) { - Append(node.returnType, " ", node.name, "("); + Append(node.returnType, " ", nameOverride, "("); bool first = true; for (const auto& parameter : node.parameters) @@ -970,8 +974,8 @@ namespace Nz void GlslWriter::Visit(ShaderAst::FunctionExpression& node) { - const std::string& targetName = Retrieve(m_currentState->previsitor.functions, node.funcId).name; - Append(targetName); + const auto& funcData = Retrieve(m_currentState->previsitor.functions, node.funcId); + Append(funcData.name); } void GlslWriter::Visit(ShaderAst::IntrinsicExpression& node) @@ -1214,7 +1218,7 @@ namespace Nz hasPredeclaration = true; auto& targetFunc = Retrieve(m_currentState->previsitor.functions, i); - AppendFunctionDeclaration(*targetFunc.node, true); + AppendFunctionDeclaration(*targetFunc.node, targetFunc.name, true); m_currentState->declaredFunctions.UnboundedSet(i); } @@ -1232,7 +1236,7 @@ namespace Nz RegisterVariable(*parameter.varIndex, parameter.name); } - AppendFunctionDeclaration(node); + AppendFunctionDeclaration(node, funcData.name); EnterScope(); { AppendStatementList(node.statements); diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 90c3ab4e6..df53ba896 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -194,8 +194,7 @@ namespace Nz void LangWriter::Append(const ShaderAst::FunctionType& functionType) { - const std::string& targetName = Retrieve(m_currentState->functions, functionType.funcIndex).name; - Append(targetName); + throw std::runtime_error("unexpected function type"); } void LangWriter::Append(const ShaderAst::IdentifierType& identifierType) @@ -971,7 +970,7 @@ namespace Nz void LangWriter::Visit(ShaderAst::FunctionExpression& node) { - Append(Retrieve(m_currentState->functions, node.funcId).name); + AppendIdentifier(m_currentState->functions, node.funcId); } void LangWriter::Visit(ShaderAst::IdentifierExpression& node) diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 778f9ba52..3edbad180 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -686,24 +686,28 @@ namespace Nz::ShaderLang ShaderAst::ExpressionValue condition; - for (auto&& [attributeType, arg] : attributes) + for (auto&& [attributeType, attributeParam] : attributes) { switch (attributeType) { case ShaderAst::AttributeType::Cond: - HandleUniqueAttribute("cond", condition, std::move(arg)); + HandleUniqueAttribute("cond", condition, std::move(attributeParam)); break; case ShaderAst::AttributeType::Entry: - HandleUniqueStringAttribute("entry", s_entryPoints, func->entryStage, std::move(arg)); + HandleUniqueStringAttribute("entry", s_entryPoints, func->entryStage, std::move(attributeParam)); + break; + + case ShaderAst::AttributeType::Export: + HandleUniqueAttribute("export", func->isExported, std::move(attributeParam), true); break; case ShaderAst::AttributeType::DepthWrite: - HandleUniqueStringAttribute("depth_write", s_depthWriteModes, func->depthWrite, std::move(arg)); + HandleUniqueStringAttribute("depth_write", s_depthWriteModes, func->depthWrite, std::move(attributeParam)); break; case ShaderAst::AttributeType::EarlyFragmentTests: - HandleUniqueAttribute("early_fragment_tests", func->earlyFragmentTests, std::move(arg), true); + HandleUniqueAttribute("early_fragment_tests", func->earlyFragmentTests, std::move(attributeParam), true); break; default: diff --git a/tests/Engine/Shader/ModuleTests.cpp b/tests/Engine/Shader/ModuleTests.cpp index 39270c1c5..4785165d0 100644 --- a/tests/Engine/Shader/ModuleTests.cpp +++ b/tests/Engine/Shader/ModuleTests.cpp @@ -32,6 +32,12 @@ struct Block data: Data } +[export] +fn GetDataValue(data: Data) -> f32 +{ + return data.value; +} + struct Unused {} [export] @@ -62,7 +68,7 @@ external fn main(input: InputData) -> OutputData { let output: OutputData; - output.value = block.data.value * input.value; + output.value = GetDataValue(block.data) * input.value; return output; } )"; @@ -78,24 +84,29 @@ fn main(input: InputData) -> OutputData shaderModule = SanitizeModule(*shaderModule, sanitizeOpt); ExpectGLSL(*shaderModule, R"( -// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4 +// Module SimpleModule -struct Data__181c45e9 +struct Data_181c45e9 { float value; }; -struct Block__181c45e9 +struct Block_181c45e9 { - Data__181c45e9 data; + Data_181c45e9 data; }; -struct InputData__181c45e9 +float GetDataValue_181c45e9(Data_181c45e9 data) +{ + return data.value; +} + +struct InputData_181c45e9 { float value; }; -struct OutputData__181c45e9 +struct OutputData_181c45e9 { float value; }; @@ -105,7 +116,7 @@ struct OutputData__181c45e9 layout(std140) uniform _NzBinding_block { - Data__181c45e9 data; + Data_181c45e9 data; } block; @@ -117,11 +128,11 @@ out float _NzOut_value; void main() { - InputData__181c45e9 input_; + InputData_181c45e9 input_; input_.value = _NzIn_value; - OutputData__181c45e9 output_; - output_.value = block.data.value * input_.value; + OutputData_181c45e9 output_; + output_.value = (GetDataValue_181c45e9(block.data)) * input_.value; _NzOut_value = output_.value; return; @@ -134,7 +145,7 @@ module; [nzsl_version("1.0")] [uuid("ad3aed6e-0619-4a26-b5ce-abc2ec0836c4")] -module __181c45e9 +module _181c45e9 { [layout(std140)] struct Data @@ -148,6 +159,11 @@ module __181c45e9 data: Data } + fn GetDataValue(data: Data) -> f32 + { + return data.value; + } + struct InputData { value: f32 @@ -159,33 +175,45 @@ module __181c45e9 } } -alias Block = __181c45e9.Block; +alias Block = _181c45e9.Block; -alias InputData = __181c45e9.InputData; +alias GetDataValue = _181c45e9.GetDataValue; -alias OutputData = __181c45e9.OutputData; +alias InputData = _181c45e9.InputData; + +alias OutputData = _181c45e9.OutputData; external { - [set(0), binding(0)] block: uniform[__181c45e9.Block] + [set(0), binding(0)] block: uniform[_181c45e9.Block] } [entry(frag)] fn main(input: InputData) -> OutputData { let output: OutputData; - output.value = block.data.value * input.value; + output.value = (GetDataValue(block.data)) * input.value; return output; } )"); ExpectSPIRV(*shaderModule, R"( OpFunction +OpFunctionParameter OpLabel +OpAccessChain +OpLoad +OpReturnValue +OpFunctionEnd +OpFunction +OpLabel +OpVariable OpVariable OpVariable OpAccessChain OpLoad +OpStore +OpFunctionCall OpAccessChain OpLoad OpFMul @@ -283,29 +311,29 @@ fn main(input: InputData) -> OutputData shaderModule = SanitizeModule(*shaderModule, sanitizeOpt); ExpectGLSL(*shaderModule, R"( -// Module ad3aed6e-0619-4a26-b5ce-abc2ec0836c4 +// Module Modules.Data -struct Data__181c45e9 +struct Data_181c45e9 { float value; }; -// Module 7a548506-89e6-4944-897f-4f695a8bca01 +// Module Modules.Block -struct Block__e528265d +struct Block_e528265d { - Data__181c45e9 data; + Data_181c45e9 data; }; -// Module e66c6e98-fc37-4390-a7e1-c81508ff8e49 +// Module Modules.InputOutput -struct InputData__26cce136 +struct InputData_26cce136 { float value; }; -struct OutputData__26cce136 +struct OutputData_26cce136 { float value; }; @@ -316,7 +344,7 @@ struct OutputData__26cce136 layout(std140) uniform _NzBinding_block { - Data__181c45e9 data; + Data_181c45e9 data; } block; @@ -328,10 +356,10 @@ out float _NzOut_value; void main() { - InputData__26cce136 input_; + InputData_26cce136 input_; input_.value = _NzIn_value; - OutputData__26cce136 output_; + OutputData_26cce136 output_; output_.value = block.data.value * input_.value; _NzOut_value = output_.value; @@ -345,7 +373,7 @@ module; [nzsl_version("1.0")] [uuid("ad3aed6e-0619-4a26-b5ce-abc2ec0836c4")] -module __181c45e9 +module _181c45e9 { [layout(std140)] struct Data @@ -356,9 +384,9 @@ module __181c45e9 } [nzsl_version("1.0")] [uuid("7a548506-89e6-4944-897f-4f695a8bca01")] -module __e528265d +module _e528265d { - alias Data = __181c45e9.Data; + alias Data = _181c45e9.Data; [layout(std140)] struct Block @@ -369,7 +397,7 @@ module __e528265d } [nzsl_version("1.0")] [uuid("e66c6e98-fc37-4390-a7e1-c81508ff8e49")] -module __26cce136 +module _26cce136 { struct InputData { @@ -382,15 +410,15 @@ module __26cce136 } } -alias Block = __e528265d.Block; +alias Block = _e528265d.Block; -alias InputData = __26cce136.InputData; +alias InputData = _26cce136.InputData; -alias OutputData = __26cce136.OutputData; +alias OutputData = _26cce136.OutputData; external { - [set(0), binding(0)] block: uniform[__e528265d.Block] + [set(0), binding(0)] block: uniform[_e528265d.Block] } [entry(frag)]