From 36aea2ca0c26ac1ff3fe32fd8cf2f87ecf5accb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Mon, 28 Feb 2022 13:30:53 +0100 Subject: [PATCH] Shader/SanitizeVisitor: Fix sanitization of already sanitized AST with holes in indices this happens when you sanitize an AST that went through a remove unused pass --- include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 14 +- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 165 ++++++++++-------- 2 files changed, 100 insertions(+), 79 deletions(-) diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 0a1ad1695..0e443ef42 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -110,13 +110,13 @@ namespace Nz::ShaderAst void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen); void RegisterBuiltin(); - std::size_t RegisterConstant(std::string name, ConstantValue value); - std::size_t RegisterFunction(std::string name, FunctionData funcData); - std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); - std::size_t RegisterStruct(std::string name, StructDescription* description); - std::size_t RegisterType(std::string name, ExpressionType expressionType); - std::size_t RegisterType(std::string name, PartialType partialType); - std::size_t RegisterVariable(std::string name, ExpressionType type); + std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional index = {}); + std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional index = {}); + std::size_t RegisterIntrinsic(std::string name, IntrinsicType type, std::optional index = {}); + std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional index = {}); + std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional index = {}); + std::size_t RegisterType(std::string name, PartialType partialType, std::optional index = {}); + std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional index = {}); void ResolveFunctions(); const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node); diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 4d67f5461..5a81b883e 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -42,6 +42,47 @@ namespace Nz::ShaderAst FunctionFlags flags; }; + template + struct IdentifierData + { + Bitset availableIndices; + std::unordered_map values; + + template + std::size_t Register(U&& data, std::optional index = {}) + { + std::size_t dataIndex; + if (index.has_value()) + dataIndex = *index; + else + { + dataIndex = availableIndices.FindFirst(); + if (dataIndex == availableIndices.npos) + dataIndex = availableIndices.GetSize(); + } + + if (dataIndex >= availableIndices.GetSize()) + availableIndices.Resize(dataIndex + 1, true); + else if (!availableIndices.Test(dataIndex)) + throw AstError{ "index " + std::to_string(dataIndex) + " is already used" }; + + assert(values.find(dataIndex) == values.end()); + + availableIndices.Set(dataIndex, false); + values.emplace(dataIndex, std::forward(data)); + return dataIndex; + } + + T& Retrieve(std::size_t index) + { + auto it = values.find(index); + if (it == values.end()) + throw AstError{ "invalid index " + std::to_string(index) }; + + return it->second; + } + }; + struct PendingFunction { DeclareFunctionStatement* cloneNode; @@ -56,16 +97,16 @@ namespace Nz::ShaderAst std::size_t nextOptionIndex = 0; Options options; std::array entryFunctions = {}; + std::vector identifiersInScope; + std::vector pendingFunctions; std::unordered_set declaredExternalVar; std::unordered_set usedBindingIndexes; - std::vector constantValues; - std::vector functions; - std::vector identifiersInScope; - std::vector intrinsics; - std::vector pendingFunctions; - std::vector structs; - std::vector> types; - std::vector variableTypes; + IdentifierData constantValues; + IdentifierData functions; + IdentifierData intrinsics; + IdentifierData structs; + IdentifierData> types; + IdentifierData variableTypes; std::vector scopes; CurrentFunctionData* currentFunction = nullptr; std::vector* currentStatementList = nullptr; @@ -185,8 +226,7 @@ namespace Nz::ShaderAst else if (IsStructType(exprType)) { std::size_t structIndex = ResolveStruct(exprType); - assert(structIndex < m_context->structs.size()); - const StructDescription* s = m_context->structs[structIndex]; + const StructDescription* s = m_context->structs.Retrieve(structIndex); // Retrieve member index (not counting disabled fields) Int32 fieldIndex = 0; @@ -513,11 +553,8 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) { - if (node.constantId >= m_context->constantValues.size()) - throw AstError{ "invalid constant index " + std::to_string(node.constantId) }; - // Replace by constant value - auto constant = ShaderBuilder::Constant(m_context->constantValues[node.constantId]); + auto constant = ShaderBuilder::Constant(m_context->constantValues.Retrieve(node.constantId)); constant->cachedExpressionType = GetExpressionType(constant->value); return constant; @@ -552,8 +589,7 @@ namespace Nz::ShaderAst case Identifier::Type::Intrinsic: { - assert(identifier->index < m_context->intrinsics.size()); - IntrinsicType intrinsicType = m_context->intrinsics[identifier->index]; + IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifier->index); auto clone = AstCloner::Clone(node); clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; @@ -581,7 +617,7 @@ namespace Nz::ShaderAst { // Replace IdentifierExpression by VariableExpression auto varExpr = std::make_unique(); - varExpr->cachedExpressionType = m_context->variableTypes[identifier->index]; + varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifier->index); varExpr->variableId = identifier->index; return varExpr; @@ -732,7 +768,7 @@ namespace Nz::ShaderAst clone->type = expressionType; - clone->constIndex = RegisterConstant(clone->name, value); + clone->constIndex = RegisterConstant(clone->name, value, clone->constIndex); if (m_context->options.removeConstDeclaration) return ShaderBuilder::NoOp(); @@ -786,7 +822,7 @@ namespace Nz::ShaderAst throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; extVar.type = std::move(resolvedType); - extVar.varIndex = RegisterVariable(extVar.name, std::move(varType)); + extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), extVar.varIndex); SanitizeIdentifier(extVar.name); } @@ -858,7 +894,7 @@ namespace Nz::ShaderAst FunctionData funcData; funcData.node = clone.get(); //< update function node - std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData)); + std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData), node.funcIndex); clone->funcIndex = funcIndex; SanitizeIdentifier(clone->name); @@ -883,9 +919,9 @@ namespace Nz::ShaderAst std::size_t optionIndex = m_context->nextOptionIndex++; if (auto optionValueIt = m_context->options.optionValues.find(optionIndex); optionValueIt != m_context->options.optionValues.end()) - clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second); + clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second, clone->optIndex); else if (clone->defaultValue) - clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue)); + clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue), clone->optIndex); else throw AstError{ "missing option " + clone->optName + " value (has no default value)" }; @@ -931,7 +967,7 @@ namespace Nz::ShaderAst else if (IsStructType(resolvedType)) { std::size_t structIndex = std::get(resolvedType).structIndex; - const StructDescription* desc = m_context->structs[structIndex]; + const StructDescription* desc = m_context->structs.Retrieve(structIndex); if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue()) throw AstError{ "inner struct layout mismatch" }; } @@ -940,7 +976,7 @@ namespace Nz::ShaderAst member.type = std::move(resolvedType); } - clone->structIndex = RegisterStruct(clone->description.name, &clone->description); + clone->structIndex = RegisterStruct(clone->description.name, &clone->description, clone->structIndex); SanitizeIdentifier(clone->description.name); @@ -1347,7 +1383,7 @@ namespace Nz::ShaderAst switch (identifier->type) { case Identifier::Type::Constant: - return m_context->constantValues[identifier->index]; + return m_context->constantValues.Retrieve(identifier->index); case Identifier::Type::Struct: return StructType{ identifier->index }; @@ -1356,7 +1392,7 @@ namespace Nz::ShaderAst return std::visit([&](auto&& arg) -> TypeParameter { return arg; - }, m_context->types[identifier->index]); + }, m_context->types.Retrieve(identifier->index)); case Identifier::Type::Alias: throw std::runtime_error("TODO"); @@ -1469,8 +1505,7 @@ namespace Nz::ShaderAst AstConstantPropagationVisitor::Options optimizerOptions; optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue& { - assert(constantId < m_context->constantValues.size()); - return m_context->constantValues[constantId]; + return m_context->constantValues.Retrieve(constantId); }; // Run optimizer on constant value to hopefully retrieve a single constant value @@ -1479,8 +1514,7 @@ namespace Nz::ShaderAst void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen) { - assert(funcIndex < m_context->functions.size()); - auto& funcData = m_context->functions[funcIndex]; + auto& funcData = m_context->functions.Retrieve(funcIndex); funcData.flags |= flags; for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i)) @@ -1651,13 +1685,12 @@ namespace Nz::ShaderAst RegisterIntrinsic("reflect", IntrinsicType::Reflect); } - std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value) + std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t constantIndex = m_context->constantValues.size(); - m_context->constantValues.emplace_back(std::move(value)); + std::size_t constantIndex = m_context->constantValues.Register(std::move(value), index); m_context->identifiersInScope.push_back({ std::move(name), @@ -1668,7 +1701,7 @@ namespace Nz::ShaderAst return constantIndex; } - std::size_t SanitizeVisitor::RegisterFunction(std::string name, FunctionData funcData) + std::size_t SanitizeVisitor::RegisterFunction(std::string name, FunctionData funcData, std::optional index) { if (auto* identifier = FindIdentifier(name)) { @@ -1677,7 +1710,7 @@ namespace Nz::ShaderAst // Functions cannot be declared twice, except for entry ones if their stages are different if (funcData.node->entryStage.HasValue() && identifier->type == Identifier::Type::Function) { - auto& otherFunction = m_context->functions[identifier->index]; + auto& otherFunction = m_context->functions.Retrieve(identifier->index); if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue()) duplicate = false; } @@ -1686,9 +1719,7 @@ namespace Nz::ShaderAst throw AstError{ funcData.node->name + " is already used" }; } - std::size_t functionIndex = m_context->functions.size(); - - m_context->functions.emplace_back(std::move(funcData)); + std::size_t functionIndex = m_context->functions.Register(std::move(funcData), index); m_context->identifiersInScope.push_back({ std::move(name), @@ -1699,13 +1730,12 @@ namespace Nz::ShaderAst return functionIndex; } - std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type) + std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t intrinsicIndex = m_context->intrinsics.size(); - m_context->intrinsics.push_back(type); + std::size_t intrinsicIndex = m_context->intrinsics.Register(std::move(type), index); m_context->identifiersInScope.push_back({ std::move(name), @@ -1716,13 +1746,12 @@ namespace Nz::ShaderAst return intrinsicIndex; } - std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description) + std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t structIndex = m_context->structs.size(); - m_context->structs.emplace_back(description); + std::size_t structIndex = m_context->structs.Register(description, index); m_context->identifiersInScope.push_back({ std::move(name), @@ -1733,13 +1762,12 @@ namespace Nz::ShaderAst return structIndex; } - std::size_t SanitizeVisitor::RegisterType(std::string name, ExpressionType expressionType) + std::size_t SanitizeVisitor::RegisterType(std::string name, ExpressionType expressionType, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t typeIndex = m_context->types.size(); - m_context->types.emplace_back(std::move(expressionType)); + std::size_t typeIndex = m_context->types.Register(std::move(expressionType), index); m_context->identifiersInScope.push_back({ std::move(name), @@ -1750,13 +1778,12 @@ namespace Nz::ShaderAst return typeIndex; } - std::size_t SanitizeVisitor::RegisterType(std::string name, PartialType partialType) + std::size_t SanitizeVisitor::RegisterType(std::string name, PartialType partialType, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t typeIndex = m_context->types.size(); - m_context->types.emplace_back(std::move(partialType)); + std::size_t typeIndex = m_context->types.Register(std::move(partialType), index); m_context->identifiersInScope.push_back({ std::move(name), @@ -1767,7 +1794,7 @@ namespace Nz::ShaderAst return typeIndex; } - std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type) + std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type, std::optional index) { if (auto* identifier = FindIdentifier(name)) { @@ -1776,8 +1803,7 @@ namespace Nz::ShaderAst throw AstError{ name + " is already used" }; } - std::size_t varIndex = m_context->variableTypes.size(); - m_context->variableTypes.emplace_back(std::move(type)); + std::size_t varIndex = m_context->variableTypes.Register(std::move(type), index); m_context->identifiersInScope.push_back({ std::move(name), @@ -1795,13 +1821,16 @@ namespace Nz::ShaderAst { PushScope(); + std::optional varIndex = pendingFunc.cloneNode->varIndex; for (auto& parameter : pendingFunc.cloneNode->parameters) { - std::size_t varIndex = RegisterVariable(parameter.name, parameter.type.GetResultingValue()); + std::size_t index = RegisterVariable(parameter.name, parameter.type.GetResultingValue(), varIndex); if (!pendingFunc.cloneNode->varIndex) - pendingFunc.cloneNode->varIndex = varIndex; //< First parameter variable index is node variable index + pendingFunc.cloneNode->varIndex = index; //< First parameter variable index is node variable index SanitizeIdentifier(parameter.name); + if (varIndex) + (*varIndex)++; } Context::CurrentFunctionData tempFuncData; @@ -1823,8 +1852,7 @@ namespace Nz::ShaderAst std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex; for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i)) { - assert(i < m_context->functions.size()); - auto& targetFunc = m_context->functions[i]; + auto& targetFunc = m_context->functions.Retrieve(i); targetFunc.calledByFunctions.UnboundedSet(funcIndex); } @@ -1832,15 +1860,13 @@ namespace Nz::ShaderAst } Bitset<> seen; - for (std::size_t funcIndex = 0; funcIndex < m_context->functions.size(); ++funcIndex) + for (const auto& [funcIndex, funcData] : m_context->functions.values) { - auto& funcData = m_context->functions[funcIndex]; - PropagateFunctionFlags(funcIndex, funcData.flags, seen); seen.Clear(); } - for (const FunctionData& funcData : m_context->functions) + for (const auto& [funcIndex, funcData] : m_context->functions.values) { if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment) throw AstError{ "discard can only be used in the fragment stage" }; @@ -1919,7 +1945,7 @@ namespace Nz::ShaderAst std::size_t typeIndex = std::get(exprType).typeIndex; - const auto& type = m_context->types[typeIndex]; + const auto& type = m_context->types.Retrieve(typeIndex); if (std::holds_alternative(type)) throw AstError{ "full type expected" }; @@ -1991,7 +2017,7 @@ namespace Nz::ShaderAst if (IsTypeExpression(exprType)) { std::size_t typeIndex = std::get(exprType).typeIndex; - const auto& type = m_context->types[typeIndex]; + const auto& type = m_context->types.Retrieve(typeIndex); if (!std::holds_alternative(type)) throw std::runtime_error("only partial types can be specialized"); @@ -2084,8 +2110,7 @@ namespace Nz::ShaderAst Int32 index = std::get(constantExpr.value); std::size_t structIndex = ResolveStruct(exprType); - assert(structIndex < m_context->structs.size()); - const StructDescription* s = m_context->structs[structIndex]; + const StructDescription* s = m_context->structs.Retrieve(structIndex); exprType = ResolveType(s->members[index].type); } @@ -2159,8 +2184,7 @@ namespace Nz::ShaderAst assert(std::holds_alternative(targetFuncType)); std::size_t targetFuncIndex = std::get(targetFuncType).funcIndex; - assert(targetFuncIndex < m_context->functions.size()); - auto& funcData = m_context->functions[targetFuncIndex]; + auto& funcData = m_context->functions.Retrieve(targetFuncIndex); const DeclareFunctionStatement* referenceDeclaration = funcData.node; @@ -2271,7 +2295,7 @@ namespace Nz::ShaderAst TypeMustMatch(resolvedType, GetExpressionType(*node.initialExpression)); } - node.varIndex = RegisterVariable(node.varName, resolvedType); + node.varIndex = RegisterVariable(node.varName, resolvedType, node.varIndex); node.varType = std::move(resolvedType); if (m_context->options.makeVariableNameUnique) @@ -2523,10 +2547,7 @@ namespace Nz::ShaderAst void SanitizeVisitor::Validate(VariableExpression& node) { - if (node.variableId >= m_context->variableTypes.size()) - throw AstError{ "invalid constant index " + std::to_string(node.variableId) }; - - node.cachedExpressionType = m_context->variableTypes[node.variableId]; + node.cachedExpressionType = m_context->variableTypes.Retrieve(node.variableId); } ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr)