diff --git a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp new file mode 100644 index 000000000..1631b684b --- /dev/null +++ b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADER_AST_ELIMINATEUNUSEDPASSVISITOR_HPP +#define NAZARA_SHADER_AST_ELIMINATEUNUSEDPASSVISITOR_HPP + +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_SHADER_API EliminateUnusedPassVisitor : AstCloner + { + public: + EliminateUnusedPassVisitor() = default; + EliminateUnusedPassVisitor(const EliminateUnusedPassVisitor&) = delete; + EliminateUnusedPassVisitor(EliminateUnusedPassVisitor&&) = delete; + ~EliminateUnusedPassVisitor() = default; + + StatementPtr Process(Statement& statement); + + EliminateUnusedPassVisitor& operator=(const EliminateUnusedPassVisitor&) = delete; + EliminateUnusedPassVisitor& operator=(EliminateUnusedPassVisitor&&) = delete; + + private: + using AstCloner::Clone; + StatementPtr Clone(DeclareExternalStatement& node) override; + StatementPtr Clone(DeclareFunctionStatement& node) override; + StatementPtr Clone(DeclareStructStatement& node) override; + StatementPtr Clone(DeclareVariableStatement& node) override; + + bool IsFunctionUsed(std::size_t varIndex) const; + bool IsStructUsed(std::size_t varIndex) const; + bool IsVariableUsed(std::size_t varIndex) const; + + struct Context; + Context* m_context; + }; + + inline StatementPtr EliminateUnusedPass(Statement& ast); +} + +#include + +#endif // NAZARA_SHADER_AST_ELIMINATEUNUSEDPASSVISITOR_HPP diff --git a/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl new file mode 100644 index 000000000..08512f309 --- /dev/null +++ b/include/Nazara/Shader/Ast/EliminateUnusedPassVisitor.inl @@ -0,0 +1,17 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + inline StatementPtr EliminateUnusedPass(Statement& ast) + { + EliminateUnusedPassVisitor visitor; + return visitor.Process(ast); + } +} + +#include diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index ae7e73f1f..8a207cbb8 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -265,11 +265,11 @@ namespace Nz::ShaderAst ExpressionValue bindingIndex; ExpressionValue bindingSet; ExpressionValue type; + std::optional varIndex; std::string name; }; ExpressionValue bindingSet; - std::optional varIndex; std::vector externalVars; }; diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index 51be1d4c3..15353aea5 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -91,8 +91,6 @@ namespace Nz::ShaderAst StatementPtr AstCloner::Clone(DeclareExternalStatement& node) { auto clone = std::make_unique(); - clone->varIndex = node.varIndex; - clone->bindingSet = Clone(node.bindingSet); clone->externalVars.reserve(node.externalVars.size()); @@ -100,6 +98,7 @@ namespace Nz::ShaderAst { auto& cloneVar = clone->externalVars.emplace_back(); cloneVar.name = var.name; + cloneVar.varIndex = var.varIndex; cloneVar.type = Clone(var.type); cloneVar.bindingIndex = Clone(var.bindingIndex); cloneVar.bindingSet = Clone(var.bindingSet); diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index e81dd04c5..5d5a2a70d 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -193,14 +193,13 @@ namespace Nz::ShaderAst void AstSerializerBase::Serialize(DeclareExternalStatement& node) { - OptVal(node.varIndex); - ExprValue(node.bindingSet); Container(node.externalVars); for (auto& extVar : node.externalVars) { Value(extVar.name); + OptVal(extVar.varIndex); ExprValue(extVar.type); ExprValue(extVar.bindingIndex); ExprValue(extVar.bindingSet); diff --git a/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp b/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp new file mode 100644 index 000000000..9fccbceac --- /dev/null +++ b/src/Nazara/Shader/Ast/EliminateUnusedPassVisitor.cpp @@ -0,0 +1,308 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + namespace + { + template T& Retrieve(std::unordered_map& map, std::size_t id) + { + auto it = map.find(id); + assert(it != map.end()); + return it->second; + } + + template const T& Retrieve(const std::unordered_map& map, std::size_t id) + { + auto it = map.find(id); + assert(it != map.end()); + return it->second; + } + + struct UsageChecker : AstRecursiveVisitor + { + struct UsageSet; + + void Resolve() + { + Resolve(globalUsage); + } + + void Resolve(const UsageSet& usageSet) + { + resolvedUsage.usedFunctions |= usageSet.usedFunctions; + resolvedUsage.usedStructs |= usageSet.usedStructs; + resolvedUsage.usedVariables |= usageSet.usedVariables; + + for (std::size_t funcIndex = usageSet.usedFunctions.FindFirst(); funcIndex != usageSet.usedFunctions.npos; funcIndex = usageSet.usedFunctions.FindNext(funcIndex)) + Resolve(Retrieve(functionUsages, funcIndex)); + + for (std::size_t structIndex = usageSet.usedStructs.FindFirst(); structIndex != usageSet.usedStructs.npos; structIndex = usageSet.usedStructs.FindNext(structIndex)) + Resolve(Retrieve(structUsages, structIndex)); + + for (std::size_t varIndex = usageSet.usedVariables.FindFirst(); varIndex != usageSet.usedVariables.npos; varIndex = usageSet.usedVariables.FindNext(varIndex)) + Resolve(Retrieve(variableUsages, varIndex)); + } + + using AstRecursiveVisitor::Visit; + + void Visit(CallFunctionExpression& node) override + { + const auto& targetFuncType = GetExpressionType(node); + assert(std::holds_alternative(targetFuncType)); + + const auto& funcType = std::get(targetFuncType); + + assert(currentFunctionIndex); + UsageSet& usageSet = Retrieve(functionUsages, *currentFunctionIndex); + usageSet.usedFunctions.UnboundedSet(funcType.funcIndex); + } + + void Visit(DeclareExternalStatement& node) override + { + for (const auto& externalVar : node.externalVars) + { + assert(externalVar.varIndex); + std::size_t varIndex = *externalVar.varIndex; + + assert(variableUsages.find(varIndex) == variableUsages.end()); + UsageSet& usageSet = variableUsages[varIndex]; + + const auto& exprType = externalVar.type.GetResultingValue(); + + if (IsUniformType(exprType)) + { + const UniformType& uniformType = std::get(exprType); + usageSet.usedStructs.UnboundedSet(uniformType.containedType.structIndex); + } + + ++varIndex; + } + + AstRecursiveVisitor::Visit(node); + } + + void Visit(DeclareFunctionStatement& node) override + { + assert(node.funcIndex); + assert(functionUsages.find(*node.funcIndex) == functionUsages.end()); + UsageSet& usageSet = functionUsages[*node.funcIndex]; + + // Register struct used in parameters or return type + if (!node.parameters.empty()) + { + assert(node.varIndex); + std::size_t parameterVarIndex = *node.varIndex; + for (auto& parameter : node.parameters) + { + // Since parameters must always be defined, their type isn't a dependency of parameter variables + assert(variableUsages.find(parameterVarIndex) == variableUsages.end()); + variableUsages.emplace(parameterVarIndex, UsageSet{}); + + const auto& exprType = parameter.type.GetResultingValue(); + if (IsStructType(exprType)) + { + std::size_t structIndex = std::get(exprType).structIndex; + usageSet.usedStructs.UnboundedSet(structIndex); + } + + ++parameterVarIndex; + } + } + + if (node.returnType.HasValue()) + { + const auto& returnExprType = node.returnType.GetResultingValue(); + if (IsStructType(returnExprType)) + { + std::size_t structIndex = std::get(returnExprType).structIndex; + usageSet.usedStructs.UnboundedSet(structIndex); + } + } + + if (node.entryStage.HasValue()) + globalUsage.usedFunctions.UnboundedSet(*node.funcIndex); + + currentFunctionIndex = node.funcIndex; + AstRecursiveVisitor::Visit(node); + currentFunctionIndex = {}; + } + + void Visit(DeclareStructStatement& node) override + { + assert(node.structIndex); + assert(structUsages.find(*node.structIndex) == structUsages.end()); + UsageSet& usageSet = structUsages[*node.structIndex]; + + for (const auto& structMember : node.description.members) + { + const auto& memberExprType = structMember.type.GetResultingValue(); + if (IsStructType(memberExprType)) + { + std::size_t structIndex = std::get(memberExprType).structIndex; + usageSet.usedStructs.UnboundedSet(structIndex); + } + } + + AstRecursiveVisitor::Visit(node); + } + + void Visit(DeclareVariableStatement& node) override + { + assert(node.varIndex); + assert(variableUsages.find(*node.varIndex) == variableUsages.end()); + UsageSet& usageSet = variableUsages[*node.varIndex]; + + const auto& varType = node.varType.GetResultingValue(); + if (IsStructType(varType)) + { + const auto& structType = std::get(varType); + usageSet.usedStructs.UnboundedSet(structType.structIndex); + } + + currentVariableDeclIndex = node.varIndex; + AstRecursiveVisitor::Visit(node); + currentVariableDeclIndex = {}; + } + + void Visit(VariableExpression& node) override + { + assert(currentFunctionIndex); + if (currentVariableDeclIndex) + { + UsageSet& usageSet = Retrieve(variableUsages, *currentVariableDeclIndex); + usageSet.usedVariables.UnboundedSet(node.variableId); + } + else + { + UsageSet& usageSet = Retrieve(functionUsages, *currentFunctionIndex); + usageSet.usedVariables.UnboundedSet(node.variableId); + } + } + + struct UsageSet + { + Bitset<> usedFunctions; + Bitset<> usedStructs; + Bitset<> usedVariables; + }; + + std::optional currentFunctionIndex; + std::optional currentVariableDeclIndex; + std::unordered_map functionUsages; + std::unordered_map structUsages; + std::unordered_map variableUsages; + UsageSet globalUsage; + UsageSet resolvedUsage; + }; + } + + struct EliminateUnusedPassVisitor::Context + { + UsageChecker usageChecker; + }; + + StatementPtr EliminateUnusedPassVisitor::Process(Statement& statement) + { + Context context; + statement.Visit(context.usageChecker); + context.usageChecker.Resolve(); + + m_context = &context; + CallOnExit onExit([this]() + { + m_context = nullptr; + }); + + return Clone(statement); + } + + StatementPtr EliminateUnusedPassVisitor::Clone(DeclareExternalStatement& node) + { + bool isUsed = false; + for (const auto& externalVar : node.externalVars) + { + assert(externalVar.varIndex); + std::size_t varIndex = *externalVar.varIndex; + + if (IsVariableUsed(varIndex)) + { + isUsed = true; + break; + } + } + + if (!isUsed) + return ShaderBuilder::NoOp(); + + auto clonedNode = AstCloner::Clone(node); + + auto& externalStatement = static_cast(*clonedNode); + for (auto it = externalStatement.externalVars.begin(); it != externalStatement.externalVars.end(); ) + { + const auto& externalVar = *it; + assert(externalVar.varIndex); + std::size_t varIndex = *externalVar.varIndex; + + if (!IsVariableUsed(varIndex)) + it = externalStatement.externalVars.erase(it); + else + ++it; + } + + return clonedNode; + } + + StatementPtr EliminateUnusedPassVisitor::Clone(DeclareFunctionStatement& node) + { + assert(node.funcIndex); + if (!IsFunctionUsed(*node.funcIndex)) + return ShaderBuilder::NoOp(); + + return AstCloner::Clone(node); + } + + StatementPtr EliminateUnusedPassVisitor::Clone(DeclareStructStatement& node) + { + assert(node.structIndex); + if (!IsStructUsed(*node.structIndex)) + return ShaderBuilder::NoOp(); + + return AstCloner::Clone(node); + } + + StatementPtr EliminateUnusedPassVisitor::Clone(DeclareVariableStatement& node) + { + assert(node.varIndex); + if (!IsVariableUsed(*node.varIndex)) + return ShaderBuilder::NoOp(); + + return AstCloner::Clone(node); + } + + bool EliminateUnusedPassVisitor::IsFunctionUsed(std::size_t varIndex) const + { + assert(m_context); + return m_context->usageChecker.resolvedUsage.usedFunctions.UnboundedTest(varIndex); + } + + bool EliminateUnusedPassVisitor::IsStructUsed(std::size_t varIndex) const + { + assert(m_context); + return m_context->usageChecker.resolvedUsage.usedStructs.UnboundedTest(varIndex); + } + + bool EliminateUnusedPassVisitor::IsVariableUsed(std::size_t varIndex) const + { + assert(m_context); + return m_context->usageChecker.resolvedUsage.usedVariables.UnboundedTest(varIndex); + } +} diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 741f1c784..eb9b14bd1 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -786,10 +786,7 @@ namespace Nz::ShaderAst throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; extVar.type = std::move(resolvedType); - - std::size_t varIndex = RegisterVariable(extVar.name, std::move(varType)); - if (!clone->varIndex) - clone->varIndex = varIndex; //< First external variable index is node variable index + extVar.varIndex = RegisterVariable(extVar.name, std::move(varType)); SanitizeIdentifier(extVar.name); } diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 315d6b31a..34935cf45 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -177,7 +178,11 @@ namespace Nz ShaderAst::StatementPtr optimizedAst; if (states.optimize) { - optimizedAst = ShaderAst::Optimize(*targetAst); + ShaderAst::StatementPtr tempAst; + + tempAst = ShaderAst::Optimize(*targetAst); + optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst); + targetAst = optimizedAst.get(); } @@ -992,9 +997,6 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node) { - assert(node.varIndex); - std::size_t varIndex = *node.varIndex; - for (const auto& externalVar : node.externalVars) { bool isStd140 = false; @@ -1075,7 +1077,8 @@ namespace Nz if (IsUniformType(externalVar.type.GetResultingValue())) AppendLine(); - RegisterVariable(varIndex++, externalVar.name); + assert(externalVar.varIndex); + RegisterVariable(*externalVar.varIndex, externalVar.name); } } diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 43d1b750c..a9143d467 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -787,9 +787,6 @@ namespace Nz void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node) { - assert(node.varIndex); - std::size_t varIndex = *node.varIndex; - AppendLine("external"); EnterScope(); @@ -804,7 +801,8 @@ namespace Nz AppendAttributes(false, SetAttribute{ externalVar.bindingSet }, BindingAttribute{ externalVar.bindingIndex }); Append(externalVar.name, ": ", externalVar.type); - RegisterVariable(varIndex++, externalVar.name); + assert(externalVar.varIndex); + RegisterVariable(*externalVar.varIndex, externalVar.name); } LeaveScope(); diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 381b3c555..c3bbb2a00 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -579,11 +579,11 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node) { - assert(node.varIndex); - - std::size_t varIndex = *node.varIndex; for (auto&& extVar : node.externalVars) - RegisterExternalVariable(varIndex++, extVar.type.GetResultingValue()); + { + assert(extVar.varIndex); + RegisterExternalVariable(*extVar.varIndex, extVar.type.GetResultingValue()); + } } void SpirvAstVisitor::Visit(ShaderAst::DeclareFunctionStatement& node) diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index a1414e254..6a62f384b 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -137,8 +138,6 @@ namespace Nz void Visit(ShaderAst::DeclareExternalStatement& node) override { - assert(node.varIndex); - std::size_t varIndex = *node.varIndex; for (auto& extVar : node.externalVars) { SpirvConstantCache::Variable variable; @@ -165,7 +164,8 @@ namespace Nz assert(extVar.bindingIndex.IsResultingValue()); - UniformVar& uniformVar = extVars[varIndex++]; + assert(extVar.varIndex); + UniformVar& uniformVar = extVars[*extVar.varIndex]; uniformVar.pointerId = m_constantCache.Register(variable); uniformVar.bindingIndex = extVar.bindingIndex.GetResultingValue(); uniformVar.descriptorSet = (extVar.bindingSet.HasValue()) ? extVar.bindingSet.GetResultingValue() : 0; @@ -519,7 +519,11 @@ namespace Nz ShaderAst::StatementPtr optimizedAst; if (states.optimize) { - optimizedAst = ShaderAst::Optimize(*targetAst); + ShaderAst::StatementPtr tempAst; + + tempAst = ShaderAst::Optimize(*targetAst); + optimizedAst = ShaderAst::EliminateUnusedPass(*tempAst); + targetAst = optimizedAst.get(); } diff --git a/src/ShaderNode/Widgets/CodeOutputWidget.cpp b/src/ShaderNode/Widgets/CodeOutputWidget.cpp index 7ea169243..bdc7c11c4 100644 --- a/src/ShaderNode/Widgets/CodeOutputWidget.cpp +++ b/src/ShaderNode/Widgets/CodeOutputWidget.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -16,7 +17,7 @@ enum class OutputLanguage { GLSL, - Nazalang, + NZSL, SpirV }; @@ -27,7 +28,7 @@ m_shaderGraph(shaderGraph) m_outputLang = new QComboBox; m_outputLang->addItem("GLSL", int(OutputLanguage::GLSL)); - m_outputLang->addItem("Nazalang", int(OutputLanguage::Nazalang)); + m_outputLang->addItem("NZSL", int(OutputLanguage::NZSL)); m_outputLang->addItem("SPIR-V", int(OutputLanguage::SpirV)); connect(m_outputLang, qOverload(&QComboBox::currentIndexChanged), [this](int) { @@ -70,7 +71,8 @@ void CodeOutputWidget::Refresh() shaderAst = Nz::ShaderAst::Sanitize(*shaderAst, sanitizeOptions); Nz::ShaderAst::AstOptimizer optimiser; - shaderAst = optimiser.Optimise(*shaderAst); + shaderAst = Nz::ShaderAst::Optimize(*shaderAst); + shaderAst = Nz::ShaderAst::EliminateUnusedPass(*shaderAst); } std::string output; @@ -91,7 +93,7 @@ void CodeOutputWidget::Refresh() break; } - case OutputLanguage::Nazalang: + case OutputLanguage::NZSL: { Nz::LangWriter writer; output = writer.Generate(*shaderAst, states);