diff --git a/include/Nazara/Graphics/UberShader.hpp b/include/Nazara/Graphics/UberShader.hpp index 7cd1f4c4f..1205ece7c 100644 --- a/include/Nazara/Graphics/UberShader.hpp +++ b/include/Nazara/Graphics/UberShader.hpp @@ -65,7 +65,7 @@ namespace Nz NazaraSignal(OnShaderUpdated, UberShader* /*uberShader*/); private: - void Validate(ShaderAst::Module& module); + ShaderAst::ModulePtr Validate(const ShaderAst::Module& module, std::unordered_map* options); NazaraSlot(ShaderModuleResolver, OnModuleUpdated, m_onShaderModuleUpdated); diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index 9b1a0de64..445eba834 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -57,6 +57,7 @@ namespace Nz::ShaderAst virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node); virtual ExpressionPtr Clone(StructTypeExpression& node); virtual ExpressionPtr Clone(SwizzleExpression& node); + virtual ExpressionPtr Clone(TypeExpression& node); virtual ExpressionPtr Clone(VariableValueExpression& node); virtual ExpressionPtr Clone(UnaryExpression& node); diff --git a/include/Nazara/Shader/Ast/AstCompare.hpp b/include/Nazara/Shader/Ast/AstCompare.hpp index dc6d0dcac..58613d77b 100644 --- a/include/Nazara/Shader/Ast/AstCompare.hpp +++ b/include/Nazara/Shader/Ast/AstCompare.hpp @@ -49,6 +49,7 @@ namespace Nz::ShaderAst inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs); inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs); inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs); + inline bool Compare(const TypeExpression& lhs, const TypeExpression& rhs); inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs); inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs); diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl index a0990be8b..bcb46a28e 100644 --- a/include/Nazara/Shader/Ast/AstCompare.inl +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -407,6 +407,14 @@ namespace Nz::ShaderAst return true; } + bool Compare(const TypeExpression& lhs, const TypeExpression& rhs) + { + if (!Compare(lhs.typeId, rhs.typeId)) + return false; + + return true; + } + inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs) { if (!Compare(lhs.variableId, rhs.variableId)) diff --git a/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp b/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp index 0bcc99e3b..d6e3bbf42 100644 --- a/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp @@ -36,7 +36,7 @@ namespace Nz::ShaderAst struct Options { - std::function constantQueryCallback; + std::function constantQueryCallback; }; protected: diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index ed23b3faa..999aa6e67 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -45,6 +45,7 @@ NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression) NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression) NAZARA_SHADERAST_EXPRESSION(StructTypeExpression) NAZARA_SHADERAST_EXPRESSION(SwizzleExpression) +NAZARA_SHADERAST_EXPRESSION(TypeExpression) NAZARA_SHADERAST_EXPRESSION(VariableValueExpression) NAZARA_SHADERAST_EXPRESSION(UnaryExpression) NAZARA_SHADERAST_STATEMENT(BranchStatement) diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index 23954126f..e297be085 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -37,6 +37,7 @@ namespace Nz::ShaderAst void Visit(IntrinsicFunctionExpression& node) override; void Visit(StructTypeExpression& node) override; void Visit(SwizzleExpression& node) override; + void Visit(TypeExpression& node) override; void Visit(VariableValueExpression& node) override; void Visit(UnaryExpression& node) override; diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index fcc673f00..dba8f483c 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -40,6 +40,7 @@ namespace Nz::ShaderAst void Serialize(IntrinsicFunctionExpression& node); void Serialize(StructTypeExpression& node); void Serialize(SwizzleExpression& node); + void Serialize(TypeExpression& node); void Serialize(VariableValueExpression& node); void Serialize(UnaryExpression& node); void SerializeExpressionCommon(Expression& expr); diff --git a/include/Nazara/Shader/Ast/AstUtils.hpp b/include/Nazara/Shader/Ast/AstUtils.hpp index 9e8271193..f92b192e1 100644 --- a/include/Nazara/Shader/Ast/AstUtils.hpp +++ b/include/Nazara/Shader/Ast/AstUtils.hpp @@ -48,6 +48,7 @@ namespace Nz::ShaderAst void Visit(IntrinsicFunctionExpression& node) override; void Visit(StructTypeExpression& node) override; void Visit(SwizzleExpression& node) override; + void Visit(TypeExpression& node) override; void Visit(VariableValueExpression& node) override; void Visit(UnaryExpression& node) override; diff --git a/include/Nazara/Shader/Ast/ConstantValue.hpp b/include/Nazara/Shader/Ast/ConstantValue.hpp index 07a1e5e19..cdf30b8f9 100644 --- a/include/Nazara/Shader/Ast/ConstantValue.hpp +++ b/include/Nazara/Shader/Ast/ConstantValue.hpp @@ -37,7 +37,7 @@ namespace Nz::ShaderAst using ConstantValue = TypeListInstantiate; - NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant); + NAZARA_SHADER_API ExpressionType GetConstantType(const ConstantValue& constant); } #endif // NAZARA_SHADER_AST_CONSTANTVALUE_HPP diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 3ffdb5c1a..2a6ca6c2d 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -215,6 +215,14 @@ namespace Nz::ShaderAst ExpressionPtr expression; }; + struct NAZARA_SHADER_API TypeExpression : Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::size_t typeId; + }; + struct NAZARA_SHADER_API VariableValueExpression : Expression { NodeType GetType() const override; @@ -462,8 +470,8 @@ namespace Nz::ShaderAst #include - inline const ExpressionType& GetExpressionType(Expression& expr); - inline ExpressionType& GetExpressionTypeMut(Expression& expr); + inline const ExpressionType* GetExpressionType(Expression& expr); + inline ExpressionType* GetExpressionTypeMut(Expression& expr); inline bool IsExpression(NodeType nodeType); inline bool IsStatement(NodeType nodeType); diff --git a/include/Nazara/Shader/Ast/Nodes.inl b/include/Nazara/Shader/Ast/Nodes.inl index 4d1ef740c..a14ecc4f4 100644 --- a/include/Nazara/Shader/Ast/Nodes.inl +++ b/include/Nazara/Shader/Ast/Nodes.inl @@ -7,16 +7,14 @@ namespace Nz::ShaderAst { - inline const ExpressionType& GetExpressionType(Expression& expr) + inline const ExpressionType* GetExpressionType(Expression& expr) { - assert(expr.cachedExpressionType); - return expr.cachedExpressionType.value(); + return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr; } - inline ExpressionType& GetExpressionTypeMut(Expression& expr) + inline ExpressionType* GetExpressionTypeMut(Expression& expr) { - assert(expr.cachedExpressionType); - return expr.cachedExpressionType.value(); + return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr; } inline const ExpressionType& ResolveAlias(const ExpressionType& exprType) diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 0fda60941..4862547b6 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -46,6 +46,7 @@ namespace Nz::ShaderAst std::shared_ptr moduleResolver; std::unordered_set reservedIdentifiers; std::unordered_map optionValues; + bool allowPartialSanitization = false; bool makeVariableNameUnique = false; bool reduceLoopsToWhile = false; bool removeAliases = false; @@ -60,6 +61,7 @@ namespace Nz::ShaderAst private: enum class IdentifierCategory; + enum class ValidationResult; struct AstError; struct CurrentFunctionData; struct Environment; @@ -110,10 +112,12 @@ namespace Nz::ShaderAst template const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const; const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const; template const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const; - TypeParameter FindTypeParameter(const std::string_view& identifierName) const; ExpressionPtr HandleIdentifier(const IdentifierData* identifierData); + const ExpressionType* GetExpressionType(Expression& expr) const; + const ExpressionType& GetExpressionTypeSecure(Expression& expr) const; + Expression& MandatoryExpr(const ExpressionPtr& node) const; Statement& MandatoryStatement(const StatementPtr& node) const; @@ -122,8 +126,9 @@ namespace Nz::ShaderAst ExpressionPtr CacheResult(ExpressionPtr expression); - ConstantValue ComputeConstantValue(Expression& expr) const; - template const T& ComputeExprValue(ExpressionValue& attribute) const; + std::optional ComputeConstantValue(Expression& expr) const; + template ValidationResult ComputeExprValue(ExpressionValue& attribute) const; + template ValidationResult ComputeExprValue(const ExpressionValue& attribute, ExpressionValue& targetAttribute); template std::unique_ptr PropagateConstants(T& node) const; void PreregisterIndices(const Module& module); @@ -131,49 +136,49 @@ namespace Nz::ShaderAst void RegisterBuiltin(); - std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional index = {}); - std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional index = {}); - std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional index = {}); + std::size_t RegisterAlias(std::string name, std::optional aliasData, std::optional index = {}); + std::size_t RegisterConstant(std::string name, std::optional value, std::optional index = {}); + std::size_t RegisterFunction(std::string name, std::optional funcData, std::optional index = {}); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex); - std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional index = {}); - std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional index = {}); - std::size_t RegisterType(std::string name, PartialType partialType, std::optional index = {}); - std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional index = {}); + std::size_t RegisterStruct(std::string name, std::optional description, std::optional index = {}); + std::size_t RegisterType(std::string name, std::optional expressionType, std::optional index = {}); + std::size_t RegisterType(std::string name, std::optional partialType, std::optional index = {}); + void RegisterUnresolved(std::string name); + std::size_t RegisterVariable(std::string name, std::optional type, std::optional index = {}); const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const; void ResolveFunctions(); - const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node); std::size_t ResolveStruct(const AliasType& aliasType); std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const IdentifierType& identifierType); std::size_t ResolveStruct(const StructType& structType); std::size_t ResolveStruct(const UniformType& uniformType); ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false); - ExpressionType ResolveType(const ExpressionValue& exprTypeValue, bool resolveAlias = false); + std::optional ResolveTypeExpr(const ExpressionValue& exprTypeValue, bool resolveAlias = false); void SanitizeIdentifier(std::string& identifier); MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error); - void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const; + ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const; void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const; StatementPtr Unscope(StatementPtr node); - void Validate(DeclareAliasStatement& node); - void Validate(WhileStatement& node); + ValidationResult Validate(DeclareAliasStatement& node); + ValidationResult Validate(WhileStatement& node); - void Validate(AccessIndexExpression& node); - void Validate(AssignExpression& node); - void Validate(BinaryExpression& node); - void Validate(CallFunctionExpression& node); - void Validate(CastExpression& node); - void Validate(DeclareVariableStatement& node); - void Validate(IntrinsicExpression& node); - void Validate(SwizzleExpression& node); - void Validate(UnaryExpression& node); - void Validate(VariableValueExpression& node); - ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr); + ValidationResult Validate(AccessIndexExpression& node); + ValidationResult Validate(AssignExpression& node); + ValidationResult Validate(BinaryExpression& node); + ValidationResult Validate(CallFunctionExpression& node); + ValidationResult Validate(CastExpression& node); + ValidationResult Validate(DeclareVariableStatement& node); + ValidationResult Validate(IntrinsicExpression& node); + ValidationResult Validate(SwizzleExpression& node); + ValidationResult Validate(UnaryExpression& node); + ValidationResult Validate(VariableValueExpression& node); + ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType); enum class IdentifierCategory { @@ -184,9 +189,16 @@ namespace Nz::ShaderAst Module, Struct, Type, + Unresolved, Variable }; + enum class ValidationResult + { + Validated, + Unresolved + }; + struct FunctionData { Bitset<> calledByFunctions; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index bff5bc9a2..48841a53d 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -173,7 +173,7 @@ namespace Nz::ShaderBuilder { auto constantNode = std::make_unique(); constantNode->value = std::move(value); - constantNode->cachedExpressionType = ShaderAst::GetExpressionType(constantNode->value); + constantNode->cachedExpressionType = ShaderAst::GetConstantType(constantNode->value); return constantNode; } diff --git a/src/Nazara/Graphics/Resources/Shaders/PhongMaterial.nzsl b/src/Nazara/Graphics/Resources/Shaders/PhongMaterial.nzsl index 94919142c..769621143 100644 --- a/src/Nazara/Graphics/Resources/Shaders/PhongMaterial.nzsl +++ b/src/Nazara/Graphics/Resources/Shaders/PhongMaterial.nzsl @@ -249,7 +249,7 @@ struct VertIn } [entry(vert), cond(Billboard)] -fn billboardMain(input: VertIn) -> VertOut +fn billboardMain(input: VertIn) -> VertToFrag { let size = input.billboardSizeRot.xy; let sinCos = input.billboardSizeRot.zw; diff --git a/src/Nazara/Graphics/UberShader.cpp b/src/Nazara/Graphics/UberShader.cpp index 2dcbfa436..205738b5c 100644 --- a/src/Nazara/Graphics/UberShader.cpp +++ b/src/Nazara/Graphics/UberShader.cpp @@ -25,7 +25,7 @@ namespace Nz m_shaderModule = moduleResolver.Resolve(moduleName); NazaraAssert(m_shaderModule, "invalid shader module"); - Validate(*m_shaderModule); + m_shaderModule = Validate(*m_shaderModule, &m_optionIndexByName); m_onShaderModuleUpdated.Connect(moduleResolver.OnModuleUpdated, [this, name = std::move(moduleName)](ShaderModuleResolver* resolver, const std::string& updatedModuleName) { @@ -41,8 +41,7 @@ namespace Nz try { - // FIXME: Validate is destructive, in case of failure it can invalidate the shader - Validate(*newShaderModule); + m_shaderModule = Validate(*newShaderModule, &m_optionIndexByName); } catch (const std::exception& e) { @@ -50,8 +49,6 @@ namespace Nz return; } - m_shaderModule = std::move(newShaderModule); - // Clear cache m_combinations.clear(); @@ -65,7 +62,7 @@ namespace Nz { NazaraAssert(m_shaderModule, "invalid shader module"); - Validate(*m_shaderModule); + Validate(*m_shaderModule, &m_optionIndexByName); } const std::shared_ptr& UberShader::Get(const Config& config) @@ -85,13 +82,17 @@ namespace Nz return it->second; } - void UberShader::Validate(ShaderAst::Module& module) + ShaderAst::ModulePtr UberShader::Validate(const ShaderAst::Module& module, std::unordered_map* options) { NazaraAssert(m_shaderStages != 0, "there must be at least one shader stage"); + assert(options); - //TODO: Try to partially sanitize shader? + // Try to partially sanitize shader - std::size_t optionCount = 0; + ShaderAst::SanitizeVisitor::Options sanitizeOptions; + sanitizeOptions.allowPartialSanitization = true; + + ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module, sanitizeOptions); ShaderStageTypeFlags supportedStageType; @@ -101,21 +102,24 @@ namespace Nz supportedStageType |= stageType; }; + std::unordered_map optionByName; callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option) { //TODO: Check optionType - m_optionIndexByName[option.optName] = Option{ + optionByName[option.optName] = Option{ CRC32(option.optName) }; - - optionCount++; }; ShaderAst::AstReflect reflect; - reflect.Reflect(*module.rootNode, callbacks); + reflect.Reflect(*sanitizedModule->rootNode, callbacks); if ((m_shaderStages & supportedStageType) != m_shaderStages) throw std::runtime_error("shader doesn't support all required shader stages"); + + *options = std::move(optionByName); + + return sanitizedModule; } } diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index c522165c8..aaeea00fc 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -481,6 +481,16 @@ namespace Nz::ShaderAst return clone; } + ExpressionPtr AstCloner::Clone(TypeExpression& node) + { + auto clone = std::make_unique(); + clone->typeId = node.typeId; + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + ExpressionPtr AstCloner::Clone(VariableValueExpression& node) { auto clone = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp b/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp index 4d28814d7..cd9b62c5b 100644 --- a/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp @@ -862,7 +862,7 @@ namespace Nz::ShaderAst const auto& constantExpr = static_cast(*expressions[i]); - if (!constantValues.empty() && GetExpressionType(constantValues.front()) != GetExpressionType(constantExpr.value)) + if (!constantValues.empty() && GetConstantType(constantValues.front()) != GetConstantType(constantExpr.value)) { // Unhandled case, all cast parameters are expected to be of the same type constantValues.clear(); @@ -940,16 +940,24 @@ namespace Nz::ShaderAst std::vector statements; StatementPtr elseStatement; + bool continuePropagation = true; for (auto& condStatement : node.condStatements) { auto cond = CloneExpression(condStatement.condition); - if (cond->GetType() == NodeType::ConstantValueExpression) + if (continuePropagation && cond->GetType() == NodeType::ConstantValueExpression) { auto& constant = static_cast(*cond); - const ExpressionType& constantType = GetExpressionType(constant); - if (!IsPrimitiveType(constantType) || std::get(constantType) != PrimitiveType::Boolean) + const ExpressionType* constantType = GetExpressionType(constant); + if (!constantType) + { + // unresolved type, can't continue propagating this branch + continuePropagation = false; + continue; + } + + if (!IsPrimitiveType(*constantType) || std::get(*constantType) != PrimitiveType::Boolean) continue; bool cValue = std::get(constant.value); @@ -1017,8 +1025,12 @@ namespace Nz::ShaderAst if (!m_options.constantQueryCallback) return AstCloner::Clone(node); - auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId)); - constant->cachedExpressionType = GetExpressionType(constant->value); + const ConstantValue* constantValue = m_options.constantQueryCallback(node.constantId); + if (!constantValue) + return AstCloner::Clone(node); + + auto constant = ShaderBuilder::Constant(*constantValue); + constant->cachedExpressionType = GetConstantType(constant->value); return constant; } @@ -1155,7 +1167,7 @@ namespace Nz::ShaderAst }, lhs.value); if (optimized) - optimized->cachedExpressionType = GetExpressionType(optimized->value); + optimized->cachedExpressionType = GetConstantType(optimized->value); return optimized; } @@ -1221,7 +1233,7 @@ namespace Nz::ShaderAst }, operand.value); if (optimized) - optimized->cachedExpressionType = GetExpressionType(optimized->value); + optimized->cachedExpressionType = GetConstantType(optimized->value); return optimized; } diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index a1bd3717b..b14768094 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -109,6 +109,11 @@ namespace Nz::ShaderAst node.expression->Visit(*this); } + void AstRecursiveVisitor::Visit(TypeExpression& node) + { + /* Nothing to do */ + } + void AstRecursiveVisitor::Visit(VariableValueExpression& /*node*/) { /* Nothing to do */ diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index ac6d03ab2..a6e553966 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -174,6 +174,11 @@ namespace Nz::ShaderAst SizeT(node.structTypeId); } + void AstSerializerBase::Serialize(TypeExpression& node) + { + SizeT(node.typeId); + } + void AstSerializerBase::Serialize(FunctionExpression& node) { SizeT(node.funcId); diff --git a/src/Nazara/Shader/Ast/AstUtils.cpp b/src/Nazara/Shader/Ast/AstUtils.cpp index c5917a2ef..c96726cae 100644 --- a/src/Nazara/Shader/Ast/AstUtils.cpp +++ b/src/Nazara/Shader/Ast/AstUtils.cpp @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include namespace Nz::ShaderAst @@ -104,7 +105,10 @@ namespace Nz::ShaderAst void ShaderAstValueCategory::Visit(SwizzleExpression& node) { - if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1) + const ExpressionType* exprType = GetExpressionType(node); + assert(exprType); + + if (IsPrimitiveType(*exprType) && node.componentCount > 1) // Swizzling more than a component on a primitive produces a rvalue (a.xxxx cannot be assigned) m_expressionCategory = ExpressionCategory::RValue; else @@ -133,6 +137,11 @@ namespace Nz::ShaderAst } } + void ShaderAstValueCategory::Visit(TypeExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::LValue; + } + void ShaderAstValueCategory::Visit(VariableValueExpression& /*node*/) { m_expressionCategory = ExpressionCategory::LValue; diff --git a/src/Nazara/Shader/Ast/ConstantValue.cpp b/src/Nazara/Shader/Ast/ConstantValue.cpp index 959b40700..2022a4aea 100644 --- a/src/Nazara/Shader/Ast/ConstantValue.cpp +++ b/src/Nazara/Shader/Ast/ConstantValue.cpp @@ -8,7 +8,7 @@ namespace Nz::ShaderAst { - ExpressionType GetExpressionType(const ConstantValue& constant) + ExpressionType GetConstantType(const ConstantValue& constant) { return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType { diff --git a/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp index e83648c97..3000b01fc 100644 --- a/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp +++ b/src/Nazara/Shader/Ast/DependencyCheckerVisitor.cpp @@ -15,10 +15,11 @@ namespace Nz::ShaderAst void DependencyCheckerVisitor::Visit(CallFunctionExpression& node) { - const auto& targetFuncType = GetExpressionType(*node.targetFunction); - assert(std::holds_alternative(targetFuncType)); + const ExpressionType* targetFuncType = GetExpressionType(*node.targetFunction); + assert(targetFuncType); + assert(std::holds_alternative(*targetFuncType)); - const auto& funcType = std::get(targetFuncType); + const auto& funcType = std::get(*targetFuncType); assert(m_currentFunctionIndex); if (m_currentVariableDeclIndex) diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index cae9333ee..3879b0f9b 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -112,6 +112,20 @@ namespace Nz::ShaderAst return it->second; } + + T* TryRetrieve(std::size_t index) + { + auto it = values.find(index); + if (it == values.end()) + { + if (!preregisteredIndices.UnboundedTest(index)) + throw AstError{ "invalid index " + std::to_string(index) }; + + return nullptr; + } + + return &it->second; + } }; struct SanitizeVisitor::Scope @@ -165,6 +179,7 @@ namespace Nz::ShaderAst ModulePtr currentModule; Options options; CurrentFunctionData* currentFunction = nullptr; + bool allowUnknownIdentifiers = false; }; ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error) @@ -266,7 +281,11 @@ namespace Nz::ShaderAst if (!exprType.HasValue()) return {}; - return ResolveType(exprType); + std::optional resolvedType = ResolveTypeExpr(exprType); + if (!resolvedType.has_value()) + return AstCloner::CloneType(exprType); + + return std::move(resolvedType).value(); } ExpressionPtr SanitizeVisitor::Clone(AccessIdentifierExpression& node) @@ -298,9 +317,13 @@ namespace Nz::ShaderAst if (identifier.empty()) throw AstError{ "empty identifier" }; - const ExpressionType& exprType = ResolveAlias(GetExpressionType(*indexedExpr)); + const ExpressionType* exprType = GetExpressionType(*indexedExpr); + if (!exprType) + return AstCloner::Clone(node); //< unresolved type + + const ExpressionType& resolvedType = ResolveAlias(*exprType); // TODO: Add proper support for methods - if (IsSamplerType(exprType)) + if (IsSamplerType(resolvedType)) { if (identifier == "Sample") { @@ -312,7 +335,7 @@ namespace Nz::ShaderAst MethodType methodType; methodType.methodIndex = 0; //< FIXME methodType.objectType = std::make_unique(); - methodType.objectType->type = exprType; + methodType.objectType->type = resolvedType; identifierExpr->cachedExpressionType = std::move(methodType); indexedExpr = std::move(identifierExpr); @@ -320,9 +343,9 @@ namespace Nz::ShaderAst else throw AstError{ "type has no method " + identifier }; } - else if (IsStructType(exprType)) + else if (IsStructType(resolvedType)) { - std::size_t structIndex = ResolveStruct(exprType); + std::size_t structIndex = ResolveStruct(resolvedType); const StructDescription* s = m_context->structs.Retrieve(structIndex); // Retrieve member index (not counting disabled fields) @@ -330,8 +353,18 @@ namespace Nz::ShaderAst const StructDescription::StructMember* fieldPtr = nullptr; for (const auto& field : s->members) { - if (field.cond.HasValue() && !field.cond.GetResultingValue()) - continue; + if (field.cond.HasValue()) + { + if (!field.cond.IsResultingValue()) + { + if (m_context->options.allowPartialSanitization) + return AstCloner::Clone(node); //< unresolved + + throw AstError{ "cond attribute is not constant" }; + } + else if (!field.cond.GetResultingValue()) + continue; + } if (field.name == identifier) { @@ -361,7 +394,7 @@ namespace Nz::ShaderAst accessIdentifierPtr = static_cast(indexedExpr.get()); accessIdentifierPtr->identifiers.push_back(fieldPtr->name); - accessIdentifierPtr->cachedExpressionType = ResolveType(fieldPtr->type); + accessIdentifierPtr->cachedExpressionType = ResolveTypeExpr(fieldPtr->type); } else { @@ -369,19 +402,19 @@ namespace Nz::ShaderAst std::unique_ptr accessIndex = std::make_unique(); accessIndex->expr = std::move(indexedExpr); accessIndex->indices.push_back(ShaderBuilder::Constant(fieldIndex)); - accessIndex->cachedExpressionType = ResolveType(fieldPtr->type); + accessIndex->cachedExpressionType = ResolveTypeExpr(fieldPtr->type); indexedExpr = std::move(accessIndex); } } - else if (IsPrimitiveType(exprType) || IsVectorType(exprType)) + else if (IsPrimitiveType(resolvedType) || IsVectorType(resolvedType)) { // Swizzle expression std::size_t swizzleComponentCount = identifier.size(); if (swizzleComponentCount > 4) throw AstError{ "cannot swizzle more than four elements" }; - if (m_context->options.removeScalarSwizzling && IsPrimitiveType(exprType)) + if (m_context->options.removeScalarSwizzling && IsPrimitiveType(resolvedType)) { for (std::size_t j = 0; j < swizzleComponentCount; ++j) { @@ -396,10 +429,10 @@ namespace Nz::ShaderAst indexedExpr = CacheResult(std::move(indexedExpr)); //< Since we are going to use a value multiple times, cache it if required PrimitiveType baseType; - if (IsVectorType(exprType)) - baseType = std::get(exprType).type; + if (IsVectorType(resolvedType)) + baseType = std::get(resolvedType).type; else - baseType = std::get(exprType); + baseType = std::get(resolvedType); auto cast = std::make_unique(); cast->targetType = ExpressionType{ VectorType{ swizzleComponentCount, baseType } }; @@ -486,8 +519,11 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node) { ExpressionPtr targetExpr = CloneExpression(MandatoryExpr(node.targetFunction)); - const ExpressionType& targetExprType = GetExpressionType(*targetExpr); - const ExpressionType& resolvedType = ResolveAlias(targetExprType); + const ExpressionType* targetExprType = GetExpressionType(*targetExpr); + if (!targetExprType) + return AstCloner::Clone(node); //< unresolved type + + const ExpressionType& resolvedType = ResolveAlias(*targetExprType); if (IsFunctionType(resolvedType)) { @@ -565,7 +601,7 @@ namespace Nz::ShaderAst { // Calling a type - vec3[f32](0.0, 1.0, 2.0) - it's a cast auto clone = std::make_unique(); - clone->targetType = std::move(targetExprType); + clone->targetType = *targetExprType; if (node.parameters.size() > clone->expressions.size()) throw AstError{ "component count doesn't match required component count" }; @@ -582,7 +618,8 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(CastExpression& node) { auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); - Validate(*clone); + if (Validate(*clone) == ValidationResult::Unresolved) + return clone; //< unresolved const ExpressionType& targetType = clone->targetType.GetResultingValue(); @@ -590,7 +627,7 @@ namespace Nz::ShaderAst { const MatrixType& targetMatrixType = std::get(targetType); - const ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front()); + const ExpressionType& frontExprType = ResolveAlias(GetExpressionTypeSecure(*clone->expressions.front())); bool isMatrixCast = IsMatrixType(frontExprType); if (isMatrixCast && std::get(frontExprType) == targetMatrixType) { @@ -627,7 +664,7 @@ namespace Nz::ShaderAst { // parameter #i vectorExpr = std::move(clone->expressions[i]); - vectorComponentCount = std::get(GetExpressionType(*vectorExpr)).componentCount; + vectorComponentCount = std::get(ResolveAlias(GetExpressionTypeSecure(*vectorExpr))).componentCount; } // cast expression (turn fromMatrix[i] to vec3[f32](fromMatrix[i])) @@ -673,7 +710,26 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConditionalExpression& node) { - return AstCloner::Clone(*ResolveCondExpression(node)); + MandatoryExpr(node.condition); + MandatoryExpr(node.truePath); + MandatoryExpr(node.falsePath); + + ExpressionPtr cloneCondition = AstCloner::Clone(*node.condition); + + std::optional conditionValue = ComputeConstantValue(*cloneCondition); + if (!conditionValue.has_value()) + { + // Unresolvable condition + return AstCloner::Clone(node); + } + + if (GetConstantType(*conditionValue) != ExpressionType{ PrimitiveType::Boolean }) + throw AstError{ "expected a boolean value" }; + + if (std::get(*conditionValue)) + return AstCloner::Clone(*node.truePath); + else + return AstCloner::Clone(*node.falsePath); } ExpressionPtr SanitizeVisitor::Clone(ConstantValueExpression& node) @@ -682,16 +738,25 @@ namespace Nz::ShaderAst throw std::runtime_error("expected a value"); auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); - clone->cachedExpressionType = GetExpressionType(clone->value); + clone->cachedExpressionType = GetConstantType(clone->value); return clone; } ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) { + const ConstantValue* value = m_context->constantValues.TryRetrieve(node.constantId); + if (!value) + { + if (!m_context->options.allowPartialSanitization) + throw std::runtime_error("invalid constant index #" + std::to_string(node.constantId)); + + return AstCloner::Clone(node); //< unresolved + } + // Replace by constant value - auto constant = ShaderBuilder::Constant(m_context->constantValues.Retrieve(node.constantId)); - constant->cachedExpressionType = GetExpressionType(constant->value); + auto constant = ShaderBuilder::Constant(*value); + constant->cachedExpressionType = GetConstantType(constant->value); return constant; } @@ -702,7 +767,15 @@ namespace Nz::ShaderAst const IdentifierData* identifierData = FindIdentifier(node.identifier); if (!identifierData) + { + if (m_context->allowUnknownIdentifiers) + return AstCloner::Clone(node); + throw AstError{ "unknown identifier " + node.identifier }; + } + + if (identifierData->category == IdentifierCategory::Unresolved) + return AstCloner::Clone(node); return HandleIdentifier(identifierData); } @@ -719,8 +792,13 @@ namespace Nz::ShaderAst { auto expression = CloneExpression(MandatoryExpr(node.expression)); - const ExpressionType& exprType = GetExpressionType(*expression); - if (m_context->options.removeScalarSwizzling && IsPrimitiveType(exprType)) + const ExpressionType* exprType = GetExpressionType(*expression); + if (!exprType) + return ShaderBuilder::Swizzle(std::move(expression), node.components, node.componentCount); //< unresolved + + const ExpressionType& resolvedExprType = ResolveAlias(*exprType); + + if (m_context->options.removeScalarSwizzling && IsPrimitiveType(resolvedExprType)) { for (std::size_t i = 0; i < node.componentCount; ++i) { @@ -735,10 +813,10 @@ namespace Nz::ShaderAst expression = CacheResult(std::move(expression)); //< Since we are going to use a value multiple times, cache it if required PrimitiveType baseType; - if (IsVectorType(exprType)) - baseType = std::get(exprType).type; + if (IsVectorType(resolvedExprType)) + baseType = std::get(resolvedExprType).type; else - baseType = std::get(exprType); + baseType = std::get(resolvedExprType); auto cast = std::make_unique(); cast->targetType = ExpressionType{ VectorType{ node.componentCount, baseType } }; @@ -786,11 +864,14 @@ namespace Nz::ShaderAst { MandatoryExpr(cond.condition); - ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*cond.condition)); - if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean }) + std::optional conditionValue = ComputeConstantValue(*AstCloner::Clone(*cond.condition)); + if (!conditionValue.has_value()) + return AstCloner::Clone(node); //< Unresolvable condition + + if (GetConstantType(*conditionValue) != ExpressionType{ PrimitiveType::Boolean }) throw AstError{ "expected a boolean value" }; - if (std::get(conditionValue)) + if (std::get(*conditionValue)) return Unscope(AstCloner::Clone(*cond.statement)); } @@ -818,24 +899,32 @@ namespace Nz::ShaderAst { condStatement.condition = CloneExpression(MandatoryExpr(cond.condition)); - const ExpressionType& condType = GetExpressionType(*condStatement.condition); - if (!IsPrimitiveType(condType) || std::get(condType) != PrimitiveType::Boolean) + const ExpressionType* condType = GetExpressionType(*condStatement.condition); + if (!condType) + return ValidationResult::Unresolved; + + if (!IsPrimitiveType(*condType) || std::get(*condType) != PrimitiveType::Boolean) throw AstError{ "branch expressions must resolve to boolean type" }; condStatement.statement = CloneStatement(MandatoryStatement(cond.statement)); + return ValidationResult::Validated; }; if (m_context->options.splitMultipleBranches && condIndex > 0) { auto currentBranch = std::make_unique(); - BuildCondStatement(currentBranch->condStatements.emplace_back()); + if (BuildCondStatement(currentBranch->condStatements.emplace_back()) == ValidationResult::Unresolved) + return AstCloner::Clone(node); root->elseStatement = std::move(currentBranch); root = static_cast(root->elseStatement.get()); } else - BuildCondStatement(clone->condStatements.emplace_back()); + { + if (BuildCondStatement(clone->condStatements.emplace_back()) == ValidationResult::Unresolved) + return AstCloner::Clone(node); + } PopScope(); } @@ -855,11 +944,19 @@ namespace Nz::ShaderAst MandatoryExpr(node.condition); MandatoryStatement(node.statement); - ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*node.condition)); - if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean }) + ExpressionPtr cloneCondition = AstCloner::Clone(*node.condition); + + std::optional conditionValue = ComputeConstantValue(*cloneCondition); + if (!conditionValue.has_value()) + { + // Unresolvable condition + return ShaderBuilder::ConditionalStatement(std::move(cloneCondition), AstCloner::Clone(*node.statement)); + } + + if (GetConstantType(*conditionValue) != ExpressionType{ PrimitiveType::Boolean }) throw AstError{ "expected a boolean value" }; - if (std::get(conditionValue)) + if (std::get(*conditionValue)) return AstCloner::Clone(*node.statement); else return ShaderBuilder::NoOp(); @@ -885,13 +982,21 @@ namespace Nz::ShaderAst clone->expression = PropagateConstants(*clone->expression); if (clone->expression->GetType() != NodeType::ConstantValueExpression) - throw AstError{ "const variable must have constant expressions " }; + { + if (!m_context->options.allowPartialSanitization) + throw AstError{ "const variable must have constant expressions " }; + + clone->constIndex = RegisterConstant(clone->name, std::nullopt, clone->constIndex); + return clone; + } const ConstantValue& value = static_cast(*clone->expression).value; - ExpressionType expressionType = ResolveType(GetExpressionType(value)); + ExpressionType expressionType = GetConstantType(value); - if (clone->type.HasValue() && ResolveType(clone->type, true) != ResolveAlias(expressionType)) + std::optional constType = ResolveTypeExpr(clone->type, true); + + if (clone->type.HasValue() && constType.has_value() && *constType != ResolveAlias(expressionType)) throw AstError{ "constant expression doesn't match type" }; clone->type = expressionType; @@ -910,9 +1015,14 @@ namespace Nz::ShaderAst auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); - UInt32 defaultBlockSet = 0; + std::optional defaultBlockSet = 0; if (clone->bindingSet.HasValue()) - defaultBlockSet = ComputeExprValue(clone->bindingSet); + { + if (ComputeExprValue(clone->bindingSet) == ValidationResult::Unresolved) + defaultBlockSet = clone->bindingSet.GetResultingValue(); + else + defaultBlockSet.reset(); //< Unresolved value + } for (auto& extVar : clone->externalVars) { @@ -921,26 +1031,36 @@ namespace Nz::ShaderAst if (extVar.bindingSet.HasValue()) ComputeExprValue(extVar.bindingSet); - else - extVar.bindingSet = defaultBlockSet; + else if (defaultBlockSet) + extVar.bindingSet = *defaultBlockSet; - UInt64 bindingSet = extVar.bindingSet.GetResultingValue(); + ComputeExprValue(extVar.bindingIndex); - UInt64 bindingIndex = ComputeExprValue(extVar.bindingIndex); + if (extVar.bindingSet.IsResultingValue() && extVar.bindingIndex.IsResultingValue()) + { + UInt64 bindingSet = extVar.bindingSet.GetResultingValue(); + UInt64 bindingIndex = extVar.bindingIndex.GetResultingValue(); - UInt64 bindingKey = bindingSet << 32 | bindingIndex; - if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end()) - throw AstError{ "binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" }; + UInt64 bindingKey = bindingSet << 32 | bindingIndex; + if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end()) + throw AstError{ "binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" }; - m_context->usedBindingIndexes.insert(bindingKey); + m_context->usedBindingIndexes.insert(bindingKey); + } if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end()) throw AstError{ "external variable " + extVar.name + " is already declared" }; m_context->declaredExternalVar.insert(extVar.name); - ExpressionType resolvedType = ResolveType(extVar.type); - const ExpressionType& targetType = ResolveAlias(resolvedType); + std::optional resolvedType = ResolveTypeExpr(extVar.type); + if (!resolvedType.has_value()) + { + RegisterUnresolved(extVar.name); + continue; + } + + const ExpressionType& targetType = ResolveAlias(*resolvedType); ExpressionType varType; if (IsUniformType(targetType)) @@ -950,7 +1070,7 @@ namespace Nz::ShaderAst else throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; - extVar.type = std::move(resolvedType); + extVar.type = std::move(resolvedType).value(); extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), extVar.varIndex); SanitizeIdentifier(extVar.name); @@ -972,35 +1092,38 @@ namespace Nz::ShaderAst { auto& cloneParam = clone->parameters.emplace_back(); cloneParam.name = parameter.name; - cloneParam.type = ResolveType(parameter.type); + cloneParam.type = CloneType(parameter.type); cloneParam.varIndex = parameter.varIndex; } if (node.returnType.HasValue()) - clone->returnType = ResolveType(node.returnType); + clone->returnType = CloneType(node.returnType); else clone->returnType = ExpressionType{ NoType{} }; if (node.depthWrite.HasValue()) - clone->depthWrite = ComputeExprValue(node.depthWrite); + ComputeExprValue(node.depthWrite, clone->depthWrite); if (node.earlyFragmentTests.HasValue()) - clone->earlyFragmentTests = ComputeExprValue(node.earlyFragmentTests); + ComputeExprValue(node.earlyFragmentTests, clone->earlyFragmentTests); if (node.entryStage.HasValue()) - clone->entryStage = ComputeExprValue(node.entryStage); + ComputeExprValue(node.entryStage, clone->entryStage); if (node.isExported.HasValue()) - clone->isExported = ComputeExprValue(node.isExported); + ComputeExprValue(node.isExported, clone->isExported); - if (clone->entryStage.HasValue()) + if (clone->entryStage.IsResultingValue()) { ShaderStageType stageType = clone->entryStage.GetResultingValue(); - if (m_context->entryFunctions[UnderlyingCast(stageType)]) - throw AstError{ "the same entry type has been defined multiple times" }; + if (!m_context->options.allowPartialSanitization) + { + if (m_context->entryFunctions[UnderlyingCast(stageType)]) + throw AstError{ "the same entry type has been defined multiple times" }; - m_context->entryFunctions[UnderlyingCast(stageType)] = &node; + m_context->entryFunctions[UnderlyingCast(stageType)] = &node; + } if (node.parameters.size() > 1) throw AstError{ "entry functions can either take one struct parameter or no parameter" }; @@ -1046,22 +1169,51 @@ namespace Nz::ShaderAst if (clone->optName.empty()) throw AstError{ "empty option name" }; - ExpressionType resolvedType = ResolveType(clone->optType); + std::optional resolvedOptionType = ResolveTypeExpr(clone->optType); + if (!resolvedOptionType) + { + clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex); + return clone; + } + + ExpressionType resolvedType = ResolveType(*resolvedOptionType); const ExpressionType& targetType = ResolveAlias(resolvedType); - if (clone->defaultValue && targetType != GetExpressionType(*clone->defaultValue)) - throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" }; + if (clone->defaultValue) + { + const ExpressionType* defaultValueType = GetExpressionType(*clone->defaultValue); + if (!defaultValueType) + { + clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex); + return clone; //< unresolved + } + + if (targetType != *defaultValueType) + throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" }; + } clone->optType = std::move(resolvedType); UInt32 optionHash = CRC32(reinterpret_cast(clone->optName.data()), clone->optName.size()); if (auto optionValueIt = m_context->options.optionValues.find(optionHash); optionValueIt != m_context->options.optionValues.end()) - clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second, clone->optIndex); - else if (clone->defaultValue) - clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue), clone->optIndex); + clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second, node.optIndex); else - throw AstError{ "missing option " + clone->optName + " value (has no default value)" }; + { + if (m_context->options.allowPartialSanitization) + { + // Partial sanitization, we cannot give a value to this option + clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex); + } + else + { + + if (!clone->defaultValue) + throw AstError{ "missing option " + clone->optName + " value (has no default value)" }; + + clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue), node.optIndex); + } + } if (m_context->options.removeOptionDeclaration) return ShaderBuilder::NoOp(); @@ -1077,33 +1229,45 @@ namespace Nz::ShaderAst auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); if (clone->isExported.HasValue()) - clone->isExported = ComputeExprValue(clone->isExported); + ComputeExprValue(clone->isExported); + + if (clone->description.layout.HasValue()) + ComputeExprValue(clone->description.layout); std::unordered_set declaredMembers; for (auto& member : clone->description.members) { if (member.cond.HasValue()) { - member.cond = ComputeExprValue(member.cond); - if (!member.cond.GetResultingValue()) + ComputeExprValue(member.cond); + if (member.cond.IsResultingValue() && !member.cond.GetResultingValue()) continue; } if (member.builtin.HasValue()) - member.builtin = ComputeExprValue(member.builtin); + ComputeExprValue(member.builtin); if (member.locationIndex.HasValue()) - member.locationIndex = ComputeExprValue(member.locationIndex); + ComputeExprValue(member.locationIndex); if (declaredMembers.find(member.name) != declaredMembers.end()) - throw AstError{ "struct member " + member.name + " found multiple time" }; + { + if ((!member.cond.HasValue() || !member.cond.IsResultingValue()) && !m_context->options.allowPartialSanitization) + throw AstError{ "struct member " + member.name + " found multiple time" }; + } declaredMembers.insert(member.name); - ExpressionType resolvedType = ResolveType(member.type); - if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140) + if (member.type.HasValue() && member.type.IsExpression()) { - const ExpressionType& targetType = ResolveAlias(resolvedType); + assert(m_context->options.allowPartialSanitization); + continue; + } + + ExpressionType resolvedType = member.type.GetResultingValue(); + if (clone->description.layout.IsResultingValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140) + { + const ExpressionType& targetType = ResolveAlias(member.type.GetResultingValue()); if (IsPrimitiveType(targetType) && std::get(targetType) == PrimitiveType::Boolean) throw AstError{ "boolean type is not allowed in std140 layout" }; @@ -1115,8 +1279,6 @@ namespace Nz::ShaderAst throw AstError{ "inner struct layout mismatch" }; } } - - member.type = std::move(resolvedType); } clone->structIndex = RegisterStruct(clone->description.name, &clone->description, clone->structIndex); @@ -1162,35 +1324,87 @@ namespace Nz::ShaderAst auto fromExpr = CloneExpression(MandatoryExpr(node.fromExpr)); auto stepExpr = CloneExpression(node.stepExpr); auto toExpr = CloneExpression(MandatoryExpr(node.toExpr)); - MandatoryStatement(node.statement); - const ExpressionType& fromExprType = GetExpressionType(*fromExpr); - if (!IsPrimitiveType(fromExprType)) + const ExpressionType* fromExprType = GetExpressionType(*fromExpr); + const ExpressionType* toExprType = GetExpressionType(*fromExpr); + + ExpressionValue unrollValue; + + auto CloneFor = [&] + { + auto clone = std::make_unique(); + clone->fromExpr = std::move(fromExpr); + clone->stepExpr = std::move(stepExpr); + clone->toExpr = std::move(toExpr); + clone->varName = node.varName; + clone->unroll = std::move(unrollValue); + + PushScope(); + { + if (fromExprType) + clone->varIndex = RegisterVariable(node.varName, *fromExprType, node.varIndex); + else + { + RegisterUnresolved(node.varName); + clone->varIndex = node.varIndex; //< preserve var index, if set + } + clone->statement = CloneStatement(node.statement); + } + PopScope(); + + SanitizeIdentifier(clone->varName); + + return clone; + }; + + if (node.unroll.HasValue() && ComputeExprValue(node.unroll, unrollValue) == ValidationResult::Unresolved) + return CloneFor(); //< unresolved unroll + + if (!fromExprType || !toExprType) + return CloneFor(); //< unresolved from/to type + + const ExpressionType& resolvedFromExprType = ResolveAlias(*fromExprType); + if (!IsPrimitiveType(resolvedFromExprType)) throw AstError{ "numerical for from expression must be an integer or unsigned integer" }; - PrimitiveType fromType = std::get(fromExprType); - if (fromType != PrimitiveType::Int32 && fromType != PrimitiveType::UInt32) + PrimitiveType counterType = std::get(resolvedFromExprType); + if (counterType != PrimitiveType::Int32 && counterType != PrimitiveType::UInt32) throw AstError{ "numerical for from expression must be an integer or unsigned integer" }; - const ExpressionType& toExprType = GetExpressionType(*fromExpr); - if (toExprType != fromExprType) + const ExpressionType& resolvedToExprType = ResolveAlias(*toExprType); + if (resolvedToExprType != resolvedFromExprType) throw AstError{ "numerical for to expression type must match from expression type" }; if (stepExpr) { - const ExpressionType& stepExprType = GetExpressionType(*fromExpr); - if (stepExprType != fromExprType) + const ExpressionType* stepExprType = GetExpressionType(*stepExpr); + if (!stepExprType) + return CloneFor(); //< unresolved step type + + const ExpressionType& resolvedStepExprType = ResolveAlias(*stepExprType); + if (resolvedStepExprType != resolvedFromExprType) throw AstError{ "numerical for step expression type must match from expression type" }; } - - ExpressionValue unrollValue; - if (node.unroll.HasValue()) + if (unrollValue.HasValue()) { - unrollValue = ComputeExprValue(node.unroll); + assert(unrollValue.IsResultingValue()); if (unrollValue.GetResultingValue() == LoopUnroll::Always) { + std::optional fromValue = ComputeConstantValue(*fromExpr); + std::optional toValue = ComputeConstantValue(*toExpr); + if (!fromValue.has_value() || !toValue.has_value()) + return CloneFor(); //< can't resolve step value + + std::optional stepValue; + if (stepExpr) + { + stepValue = ComputeConstantValue(*stepExpr); + if (!stepValue.has_value()) + return CloneFor(); //< can't resolve step value + } + PushScope(); auto multi = std::make_unique(); @@ -1199,9 +1413,9 @@ namespace Nz::ShaderAst { using T = std::decay_t; - T counter = std::get(ComputeConstantValue(*fromExpr)); - T to = std::get(ComputeConstantValue(*toExpr)); - T step = (stepExpr) ? std::get(ComputeConstantValue(*stepExpr)) : T(1); + T counter = std::get(*fromValue); + T to = std::get(*toValue); + T step = (stepExpr) ? std::get(*stepValue) : T(1); for (; counter < to; counter += step) { @@ -1213,7 +1427,7 @@ namespace Nz::ShaderAst } }; - switch (fromType) + switch (counterType) { case PrimitiveType::Int32: Unroll(Int32{}); @@ -1271,7 +1485,7 @@ namespace Nz::ShaderAst whileStatement->unroll = std::move(unrollValue); // While condition - auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, fromType), ShaderBuilder::Variable(targetVarIndex, fromType)); + auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, counterType), ShaderBuilder::Variable(targetVarIndex, counterType)); Validate(*condition); whileStatement->condition = std::move(condition); @@ -1284,11 +1498,11 @@ namespace Nz::ShaderAst ExpressionPtr incrExpr; if (stepVarIndex) - incrExpr = ShaderBuilder::Variable(*stepVarIndex, fromType); + incrExpr = ShaderBuilder::Variable(*stepVarIndex, counterType); else - incrExpr = (fromType == PrimitiveType::Int32) ? ShaderBuilder::Constant(1) : ShaderBuilder::Constant(1u); + incrExpr = (counterType == PrimitiveType::Int32) ? ShaderBuilder::Constant(1) : ShaderBuilder::Constant(1u); - auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, fromType), std::move(incrExpr)); + auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, counterType), std::move(incrExpr)); Validate(*incrCounter); body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); @@ -1302,25 +1516,7 @@ namespace Nz::ShaderAst return multi; } else - { - auto clone = std::make_unique(); - clone->fromExpr = std::move(fromExpr); - clone->stepExpr = std::move(stepExpr); - clone->toExpr = std::move(toExpr); - clone->varName = node.varName; - clone->unroll = std::move(unrollValue); - - PushScope(); - { - clone->varIndex = RegisterVariable(node.varName, fromExprType, node.varIndex); - clone->statement = CloneStatement(node.statement); - } - PopScope(); - - SanitizeIdentifier(clone->varName); - - return clone; - } + return CloneFor(); } StatementPtr SanitizeVisitor::Clone(ForEachStatement& node) @@ -1330,11 +1526,16 @@ namespace Nz::ShaderAst if (node.varName.empty()) throw AstError{ "for-each variable name cannot be empty"}; - const ExpressionType& exprType = GetExpressionType(*expr); + const ExpressionType* exprType = GetExpressionType(*expr); + if (!exprType) + return AstCloner::Clone(node); //< unresolved expression type + + const ExpressionType& resolvedExprType = ResolveAlias(*exprType); + ExpressionType innerType; - if (IsArrayType(exprType)) + if (IsArrayType(resolvedExprType)) { - const ArrayType& arrayType = std::get(exprType); + const ArrayType& arrayType = std::get(resolvedExprType); innerType = arrayType.containedType->type; } else @@ -1343,16 +1544,18 @@ namespace Nz::ShaderAst ExpressionValue unrollValue; if (node.unroll.HasValue()) { - unrollValue = ComputeExprValue(node.unroll); + if (ComputeExprValue(node.unroll, unrollValue) == ValidationResult::Unresolved) + return AstCloner::Clone(node); //< unresolved unroll type + if (unrollValue.GetResultingValue() == LoopUnroll::Always) { PushScope(); // Repeat code auto multi = std::make_unique(); - if (IsArrayType(exprType)) + if (IsArrayType(resolvedExprType)) { - const ArrayType& arrayType = std::get(exprType); + const ArrayType& arrayType = std::get(resolvedExprType); for (UInt32 i = 0; i < arrayType.length; ++i) { @@ -1379,9 +1582,9 @@ namespace Nz::ShaderAst auto multi = std::make_unique(); - if (IsArrayType(exprType)) + if (IsArrayType(resolvedExprType)) { - const ArrayType& arrayType = std::get(exprType); + const ArrayType& arrayType = std::get(resolvedExprType); multi->statements.reserve(2); @@ -1452,7 +1655,15 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(ImportStatement& node) { if (!m_context->options.moduleResolver) + { + if (!m_context->options.allowPartialSanitization) + throw AstError{ "module " + node.moduleName + " not found" }; + + // when partially sanitizing, importing a whole module could register any identifier, so at this point we can't see unknown identifiers as errors + m_context->allowUnknownIdentifiers = true; + return StaticUniquePointerCast(AstCloner::Clone(node)); + } ModulePtr targetModule = m_context->options.moduleResolver->Resolve(node.moduleName); if (!targetModule) @@ -1487,15 +1698,21 @@ namespace Nz::ShaderAst std::string error; sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error); - if (!sanitizedModule) + if (!sanitizedModule->rootNode) throw AstError{ "module " + node.moduleName + " compilation failed: " + error }; moduleIndex = m_context->modules.size(); assert(m_context->modules.size() == moduleIndex); auto& moduleData = m_context->modules.emplace_back(); - moduleData.dependenciesVisitor = std::make_unique(); - moduleData.dependenciesVisitor->Process(*sanitizedModule->rootNode); + + // Don't run dependency checker when partially sanitizing + if (!m_context->options.allowPartialSanitization) + { + moduleData.dependenciesVisitor = std::make_unique(); + moduleData.dependenciesVisitor->Process(*sanitizedModule->rootNode); + } + moduleData.environment = std::move(moduleEnvironment); assert(m_context->currentModule->importedModules.size() == moduleIndex); @@ -1529,7 +1746,8 @@ namespace Nz::ShaderAst { assert(node.funcIndex); - moduleData.dependenciesVisitor->MarkFunctionAsUsed(*node.funcIndex); + if (moduleData.dependenciesVisitor) + moduleData.dependenciesVisitor->MarkFunctionAsUsed(*node.funcIndex); if (!exportedSet.usedFunctions.UnboundedTest(*node.funcIndex)) { @@ -1542,7 +1760,8 @@ namespace Nz::ShaderAst { assert(node.structIndex); - moduleData.dependenciesVisitor->MarkStructAsUsed(*node.structIndex); + if (moduleData.dependenciesVisitor) + moduleData.dependenciesVisitor->MarkStructAsUsed(*node.structIndex); if (!exportedSet.usedStructs.UnboundedTest(*node.structIndex)) { @@ -1569,6 +1788,8 @@ namespace Nz::ShaderAst for (auto& aliasPtr : aliasStatements) aliasBlock->statements.push_back(std::move(aliasPtr)); + m_context->allowUnknownIdentifiers = true; //< if module uses a unresolved and non-exported symbol, we need to allow unknown identifiers + return aliasBlock; } @@ -1607,13 +1828,12 @@ namespace Nz::ShaderAst MandatoryStatement(node.body); auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); - Validate(*clone); + if (Validate(*clone) == ValidationResult::Unresolved) + return clone; - ExpressionValue unrollValue; - if (node.unroll.HasValue()) + if (clone->unroll.HasValue()) { - clone->unroll = ComputeExprValue(node.unroll); - if (clone->unroll.GetResultingValue() == LoopUnroll::Always) + if (ComputeExprValue(clone->unroll) == ValidationResult::Validated && clone->unroll.GetResultingValue() == LoopUnroll::Always) throw AstError{ "unroll(always) is not yet supported on while" }; } @@ -1669,55 +1889,6 @@ namespace Nz::ShaderAst return &it->data; } - TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const - { - const auto* identifier = FindIdentifier(identifierName); - if (!identifier) - throw std::runtime_error("identifier " + std::string(identifierName) + " not found"); - - switch (identifier->category) - { - case IdentifierCategory::Constant: - return m_context->constantValues.Retrieve(identifier->index); - - case IdentifierCategory::Struct: - return StructType{ identifier->index }; - - case IdentifierCategory::Type: - return std::visit([&](auto&& arg) -> TypeParameter - { - return arg; - }, m_context->types.Retrieve(identifier->index)); - - case IdentifierCategory::Alias: - { - IdentifierCategory category; - std::size_t index; - do - { - const auto& aliasData = m_context->aliases.Retrieve(identifier->index); - category = aliasData.category; - index = aliasData.index; - } - while (category == IdentifierCategory::Alias); - } - - case IdentifierCategory::Function: - throw std::runtime_error("unexpected function identifier"); - - case IdentifierCategory::Intrinsic: - throw std::runtime_error("unexpected intrinsic identifier"); - - case IdentifierCategory::Module: - throw std::runtime_error("unexpected module identifier"); - - case IdentifierCategory::Variable: - throw std::runtime_error("unexpected variable identifier"); - } - - throw std::runtime_error("internal error"); - } - ExpressionPtr SanitizeVisitor::HandleIdentifier(const IdentifierData* identifierData) { switch (identifierData->category) @@ -1761,6 +1932,9 @@ namespace Nz::ShaderAst return intrinsicExpr; } + case IdentifierCategory::Module: + throw AstError{ "unexpected module identifier" }; + case IdentifierCategory::Struct: { // Replace IdentifierExpression by StructTypeExpression @@ -1773,12 +1947,16 @@ namespace Nz::ShaderAst case IdentifierCategory::Type: { - auto clone = ShaderBuilder::Identifier("dummy"); - clone->cachedExpressionType = Type{ identifierData->index }; + auto typeExpr = std::make_unique(); + typeExpr->cachedExpressionType = Type{ identifierData->index }; + typeExpr->typeId = identifierData->index; - return clone; + return typeExpr; } + case IdentifierCategory::Unresolved: + throw AstError{ "unexpected unresolved identifier" }; + case IdentifierCategory::Variable: { // Replace IdentifierExpression by VariableExpression @@ -1788,10 +1966,30 @@ namespace Nz::ShaderAst return varExpr; } - - default: - throw AstError{ "unexpected identifier" }; } + + throw AstError{ "internal error" }; + } + + const ExpressionType* SanitizeVisitor::GetExpressionType(Expression& expr) const + { + const ExpressionType* expressionType = ShaderAst::GetExpressionType(expr); + if (!expressionType) + { + if (!m_context->options.allowPartialSanitization) + throw AstError{ "unexpected missing expression type" }; //< InternalError + } + + return expressionType; + } + + const ExpressionType& SanitizeVisitor::GetExpressionTypeSecure(Expression& expr) const + { + const ExpressionType* expressionType = GetExpressionType(expr); + if (!expressionType) + throw AstError{ "unexpected missing expression type" }; //< InternalError + + return *expressionType; } Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node) const @@ -1843,53 +2041,104 @@ namespace Nz::ShaderAst return varExpr; } - ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr) const + std::optional SanitizeVisitor::ComputeConstantValue(Expression& expr) const { // Run optimizer on constant value to hopefully retrieve a single constant value ExpressionPtr optimizedExpr = PropagateConstants(expr); if (optimizedExpr->GetType() != NodeType::ConstantValueExpression) - throw AstError{"expected a constant expression"}; + { + if (!m_context->options.allowPartialSanitization) + throw AstError{ "expected a constant expression" }; + + return std::nullopt; + } return static_cast(*optimizedExpr).value; } template - const T& SanitizeVisitor::ComputeExprValue(ExpressionValue& attribute) const + auto SanitizeVisitor::ComputeExprValue(ExpressionValue& attribute) const -> ValidationResult { if (!attribute.HasValue()) throw AstError{ "attribute expected a value" }; if (attribute.IsExpression()) { - ConstantValue value = ComputeConstantValue(*attribute.GetExpression()); + std::optional value = ComputeConstantValue(*attribute.GetExpression()); + if (!value) + return ValidationResult::Unresolved; + if constexpr (TypeListFind) { - if (!std::holds_alternative(value)) + if (!std::holds_alternative(*value)) { // HAAAAAX - if (std::holds_alternative(value) && std::is_same_v) - attribute = static_cast(std::get(value)); + if (std::holds_alternative(*value) && std::is_same_v) + attribute = static_cast(std::get(*value)); else throw AstError{ "unexpected attribute type" }; } else - attribute = std::get(value); + attribute = std::get(*value); } else throw AstError{ "unexpected expression for this type" }; } - assert(attribute.IsResultingValue()); - return attribute.GetResultingValue(); + return ValidationResult::Validated; + } + + template + auto SanitizeVisitor::ComputeExprValue(const ExpressionValue& attribute, ExpressionValue& targetAttribute) -> ValidationResult + { + if (!attribute.HasValue()) + throw AstError{ "attribute expected a value" }; + + if (attribute.IsExpression()) + { + std::optional value = ComputeConstantValue(*attribute.GetExpression()); + if (!value) + { + targetAttribute = AstCloner::Clone(*attribute.GetExpression()); + return ValidationResult::Unresolved; + } + + if constexpr (TypeListFind) + { + if (!std::holds_alternative(*value)) + { + // HAAAAAX + if (std::holds_alternative(*value) && std::is_same_v) + targetAttribute = static_cast(std::get(*value)); + else + throw AstError{ "unexpected attribute type" }; + } + else + targetAttribute = std::get(*value); + } + else + throw AstError{ "unexpected expression for this type" }; + } + else + { + assert(attribute.IsResultingValue()); + targetAttribute = attribute.GetResultingValue(); + } + + return ValidationResult::Validated; } template std::unique_ptr SanitizeVisitor::PropagateConstants(T& node) const { AstConstantPropagationVisitor::Options optimizerOptions; - optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue& + optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue* { - return m_context->constantValues.Retrieve(constantId); + const ConstantValue* value = m_context->constantValues.TryRetrieve(constantId); + if (!value && !m_context->options.allowPartialSanitization) + throw AstError{ "invalid constant index #" + std::to_string(constantId) }; + + return value; }; // Run optimizer on constant value to hopefully retrieve a single constant value @@ -2090,12 +2339,22 @@ namespace Nz::ShaderAst RegisterIntrinsic("reflect", IntrinsicType::Reflect); } - std::size_t SanitizeVisitor::RegisterAlias(std::string name, IdentifierData aliasData, std::optional index) + std::size_t SanitizeVisitor::RegisterAlias(std::string name, std::optional aliasData, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t aliasIndex = m_context->aliases.Register(std::move(aliasData), index); + std::size_t aliasIndex; + if (aliasData) + aliasIndex = m_context->aliases.Register(std::move(*aliasData), index); + else if (index) + { + m_context->aliases.PreregisterIndex(*index); + aliasIndex = *index; + } + else + aliasIndex = m_context->aliases.RegisterNewIndex(true); + m_context->currentEnv->identifiersInScope.push_back({ std::move(name), aliasIndex, @@ -2105,12 +2364,22 @@ namespace Nz::ShaderAst return aliasIndex; } - std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value, std::optional index) + std::size_t SanitizeVisitor::RegisterConstant(std::string name, std::optional value, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t constantIndex = m_context->constantValues.Register(std::move(value), index); + std::size_t constantIndex; + if (value) + constantIndex = m_context->constantValues.Register(std::move(*value), index); + else if (index) + { + m_context->constantValues.PreregisterIndex(*index); + constantIndex = *index; + } + else + constantIndex = m_context->constantValues.RegisterNewIndex(true); + m_context->currentEnv->identifiersInScope.push_back({ std::move(name), constantIndex, @@ -2120,25 +2389,45 @@ namespace Nz::ShaderAst return constantIndex; } - std::size_t SanitizeVisitor::RegisterFunction(std::string name, FunctionData funcData, std::optional index) + std::size_t SanitizeVisitor::RegisterFunction(std::string name, std::optional funcData, std::optional index) { if (auto* identifier = FindIdentifier(name)) { - bool duplicate = true; + // Functions can be conditionally defined and condition not resolved yet, allow duplicates when partially sanitizing + bool duplicate = !m_context->options.allowPartialSanitization; // Functions cannot be declared twice, except for entry ones if their stages are different - if (funcData.node->entryStage.HasValue() && identifier->category == IdentifierCategory::Function) + if (funcData) { - auto& otherFunction = m_context->functions.Retrieve(identifier->index); - if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue()) - duplicate = false; + if (funcData->node->entryStage.HasValue() && identifier->category == IdentifierCategory::Function) + { + auto& otherFunction = m_context->functions.Retrieve(identifier->index); + if (funcData->node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue()) + duplicate = false; + } + } + else + { + if (!m_context->options.allowPartialSanitization) + throw AstError{ "internal error" }; + + duplicate = false; } if (duplicate) throw AstError{ name + " is already used" }; } - std::size_t functionIndex = m_context->functions.Register(std::move(funcData), index); + std::size_t functionIndex; + if (funcData) + functionIndex = m_context->functions.Register(std::move(*funcData), index); + else if (index) + { + m_context->functions.PreregisterIndex(*index); + functionIndex = *index; + } + else + functionIndex = m_context->functions.RegisterNewIndex(true); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), @@ -2181,12 +2470,21 @@ namespace Nz::ShaderAst return moduleIndex; } - std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional index) + std::size_t SanitizeVisitor::RegisterStruct(std::string name, std::optional description, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t structIndex = m_context->structs.Register(description, index); + std::size_t structIndex; + if (description) + structIndex = m_context->structs.Register(*description, index); + else if (index) + { + m_context->structs.PreregisterIndex(*index); + structIndex = *index; + } + else + structIndex = m_context->structs.RegisterNewIndex(true); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), @@ -2197,12 +2495,21 @@ namespace Nz::ShaderAst return structIndex; } - std::size_t SanitizeVisitor::RegisterType(std::string name, ExpressionType expressionType, std::optional index) + std::size_t SanitizeVisitor::RegisterType(std::string name, std::optional expressionType, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t typeIndex = m_context->types.Register(std::move(expressionType), index); + std::size_t typeIndex; + if (expressionType) + typeIndex = m_context->types.Register(std::move(*expressionType), index); + else if (index) + { + m_context->types.PreregisterIndex(*index); + typeIndex = *index; + } + else + typeIndex = m_context->types.RegisterNewIndex(true); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), @@ -2213,12 +2520,21 @@ namespace Nz::ShaderAst return typeIndex; } - std::size_t SanitizeVisitor::RegisterType(std::string name, PartialType partialType, std::optional index) + std::size_t SanitizeVisitor::RegisterType(std::string name, std::optional partialType, std::optional index) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t typeIndex = m_context->types.Register(std::move(partialType), index); + std::size_t typeIndex; + if (partialType) + typeIndex = m_context->types.Register(std::move(*partialType), index); + else if (index) + { + m_context->types.PreregisterIndex(*index); + typeIndex = *index; + } + else + typeIndex = m_context->types.RegisterNewIndex(true); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), @@ -2229,7 +2545,16 @@ namespace Nz::ShaderAst return typeIndex; } - std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type, std::optional index) + void SanitizeVisitor::RegisterUnresolved(std::string name) + { + m_context->currentEnv->identifiersInScope.push_back({ + std::move(name), + std::numeric_limits::max(), + IdentifierCategory::Unresolved + }); + } + + std::size_t SanitizeVisitor::RegisterVariable(std::string name, std::optional type, std::optional index) { if (auto* identifier = FindIdentifier(name)) { @@ -2238,7 +2563,16 @@ namespace Nz::ShaderAst throw AstError{ name + " is already used" }; } - std::size_t varIndex = m_context->variableTypes.Register(std::move(type), index); + std::size_t varIndex; + if (type) + varIndex = m_context->variableTypes.Register(std::move(*type), index); + else if (index) + { + m_context->variableTypes.PreregisterIndex(*index); + varIndex = *index; + } + else + varIndex = m_context->variableTypes.RegisterNewIndex(true); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), @@ -2311,23 +2645,6 @@ namespace Nz::ShaderAst } } - const ExpressionPtr& SanitizeVisitor::ResolveCondExpression(ConditionalExpression& node) - { - MandatoryExpr(node.condition); - MandatoryExpr(node.truePath); - MandatoryExpr(node.falsePath); - - ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*node.condition)); - if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean }) - throw AstError{ "expected a boolean value" }; - - if (std::get(conditionValue)) - return node.truePath; - else - return node.falsePath; - - } - std::size_t SanitizeVisitor::ResolveStruct(const AliasType& aliasType) { return ResolveStruct(aliasType.targetType->type); @@ -2400,23 +2717,24 @@ namespace Nz::ShaderAst return std::get(type); } - ExpressionType SanitizeVisitor::ResolveType(const ExpressionValue& exprTypeValue, bool resolveAlias) + std::optional SanitizeVisitor::ResolveTypeExpr(const ExpressionValue& exprTypeValue, bool resolveAlias) { if (!exprTypeValue.HasValue()) - return {}; + return NoType{}; if (exprTypeValue.IsResultingValue()) return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias); assert(exprTypeValue.IsExpression()); ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression()); - assert(expression->cachedExpressionType); + const ExpressionType* exprType = GetExpressionType(*expression); + if (!exprType) + return std::nullopt; - const ExpressionType& exprType = expression->cachedExpressionType.value(); //if (!IsTypeType(exprType)) // throw AstError{ "type expected" }; - return ResolveType(exprType, resolveAlias); + return ResolveType(*exprType, resolveAlias); } void SanitizeVisitor::SanitizeIdentifier(std::string& identifier) @@ -2462,9 +2780,15 @@ namespace Nz::ShaderAst return output; } - void SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const + auto SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const -> ValidationResult { - return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right)); + const ExpressionType* leftType = GetExpressionType(*left); + const ExpressionType* rightType = GetExpressionType(*right); + if (!leftType || !rightType) + return ValidationResult::Unresolved; + + TypeMustMatch(*leftType, *rightType); + return ValidationResult::Validated; } void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const @@ -2483,43 +2807,65 @@ namespace Nz::ShaderAst return node; } - void SanitizeVisitor::Validate(DeclareAliasStatement& node) + auto SanitizeVisitor::Validate(DeclareAliasStatement& node) -> ValidationResult { if (node.name.empty()) throw std::runtime_error("invalid alias name"); - ExpressionType exprType = GetExpressionType(*node.expression); - if (IsStructType(exprType)) + const ExpressionType* exprType = GetExpressionType(*node.expression); + if (!exprType) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedType = ResolveAlias(*exprType); + + IdentifierData targetIdentifier; + if (IsStructType(resolvedType)) { - std::size_t structIndex = ResolveStruct(exprType); - node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex); + std::size_t structIndex = ResolveStruct(resolvedType); + targetIdentifier = { structIndex, IdentifierCategory::Struct }; } - else if (IsFunctionType(exprType)) + else if (IsFunctionType(resolvedType)) { - std::size_t funcIndex = std::get(exprType).funcIndex; - node.aliasIndex = RegisterAlias(node.name, { funcIndex, IdentifierCategory::Function }, node.aliasIndex); + std::size_t funcIndex = std::get(resolvedType).funcIndex; + targetIdentifier = { funcIndex, IdentifierCategory::Function }; } - else if (IsAliasType(exprType)) + else if (IsAliasType(resolvedType)) { - const AliasType& alias = std::get(exprType); - node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex); + const AliasType& alias = std::get(resolvedType); + targetIdentifier = { alias.aliasIndex, IdentifierCategory::Alias }; } else throw AstError{ "for now, only aliases, functions and structs can be aliased" }; + + node.aliasIndex = RegisterAlias(node.name, targetIdentifier, node.aliasIndex); + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(WhileStatement& node) + auto SanitizeVisitor::Validate(WhileStatement& node) -> ValidationResult { - if (GetExpressionType(*node.condition) != ExpressionType{ PrimitiveType::Boolean }) + const ExpressionType* conditionType = GetExpressionType(MandatoryExpr(node.condition)); + MandatoryStatement(node.body); + + if (!conditionType) + return ValidationResult::Unresolved; + + if (ResolveAlias(*conditionType) != ExpressionType{ PrimitiveType::Boolean }) throw AstError{ "expected a boolean value" }; + + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(AccessIndexExpression& node) + auto SanitizeVisitor::Validate(AccessIndexExpression& node) -> ValidationResult { - ExpressionType exprType = GetExpressionType(*node.expr); - if (IsTypeExpression(exprType)) + const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expr)); + if (!exprType) + return ValidationResult::Unresolved; + + ExpressionType resolvedExprType = ResolveAlias(*exprType); + + if (IsTypeExpression(resolvedExprType)) { - std::size_t typeIndex = std::get(exprType).typeIndex; + std::size_t typeIndex = std::get(resolvedExprType).typeIndex; const auto& type = m_context->types.Retrieve(typeIndex); if (!std::holds_alternative(type)) @@ -2537,7 +2883,11 @@ namespace Nz::ShaderAst { case TypeParameterCategory::ConstantValue: { - parameters.push_back(ComputeConstantValue(*indexExpr)); + std::optional value = ComputeConstantValue(*indexExpr); + if (!value.has_value()) + return ValidationResult::Unresolved; + + parameters.push_back(std::move(*value)); break; } @@ -2545,7 +2895,11 @@ namespace Nz::ShaderAst case TypeParameterCategory::PrimitiveType: case TypeParameterCategory::StructType: { - ExpressionType resolvedType = ResolveType(GetExpressionType(*indexExpr), true); + const ExpressionType* indexExprType = GetExpressionType(*indexExpr); + if (!indexExprType) + return ValidationResult::Unresolved; + + ExpressionType resolvedType = ResolveType(*indexExprType, true); switch (partialType.parameters[i]) { @@ -2583,65 +2937,81 @@ namespace Nz::ShaderAst if (node.indices.size() != 1) throw AstError{ "AccessIndexExpression must have at one index" }; - for (auto& index : node.indices) - { - const ExpressionType& indexType = GetExpressionType(*index); - if (!IsPrimitiveType(indexType)) - throw AstError{ "AccessIndex expects integer indices" }; - - PrimitiveType primitiveIndexType = std::get(indexType); - if (primitiveIndexType != PrimitiveType::Int32 && primitiveIndexType != PrimitiveType::UInt32) - throw AstError{ "AccessIndex expects integer indices" }; - } - for (const auto& indexExpr : node.indices) { - if (IsArrayType(exprType)) + const ExpressionType* indexType = GetExpressionType(*indexExpr); + if (!indexType) + return ValidationResult::Unresolved; + + if (!IsPrimitiveType(*indexType)) + throw AstError{ "AccessIndex expects integer indices" }; + + PrimitiveType primitiveIndexType = std::get(*indexType); + if (primitiveIndexType != PrimitiveType::Int32 && primitiveIndexType != PrimitiveType::UInt32) + throw AstError{ "AccessIndex expects integer indices" }; + + if (IsArrayType(resolvedExprType)) { - const ArrayType& arrayType = std::get(exprType); + const ArrayType& arrayType = std::get(resolvedExprType); ExpressionType containedType = arrayType.containedType->type; //< Don't overwrite exprType directly since it contains arrayType - exprType = std::move(containedType); + resolvedExprType = std::move(containedType); } - else if (IsStructType(exprType)) + else if (IsStructType(resolvedExprType)) { - const ExpressionType& indexType = GetExpressionType(*indexExpr); - if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 }) + if (primitiveIndexType != PrimitiveType::Int32) throw AstError{ "struct can only be accessed with constant i32 indices" }; ConstantValueExpression& constantExpr = static_cast(*indexExpr); Int32 index = std::get(constantExpr.value); - std::size_t structIndex = ResolveStruct(exprType); + std::size_t structIndex = ResolveStruct(resolvedExprType); const StructDescription* s = m_context->structs.Retrieve(structIndex); - exprType = ResolveType(s->members[index].type, true); + std::optional resolvedExprTypeOpt = ResolveTypeExpr(s->members[index].type, true); + if (!resolvedExprTypeOpt.has_value()) + return ValidationResult::Unresolved; + + resolvedExprType = std::move(resolvedExprTypeOpt).value(); } - else if (IsMatrixType(exprType)) + else if (IsMatrixType(resolvedExprType)) { // Matrix index (ex: mat[2]) - MatrixType matrixType = std::get(exprType); + MatrixType matrixType = std::get(resolvedExprType); //TODO: Handle row-major matrices - exprType = VectorType{ matrixType.rowCount, matrixType.type }; + resolvedExprType = VectorType{ matrixType.rowCount, matrixType.type }; } - else if (IsVectorType(exprType)) + else if (IsVectorType(resolvedExprType)) { // Swizzle expression with one component (ex: vec[2]) - VectorType swizzledVec = std::get(exprType); + VectorType swizzledVec = std::get(resolvedExprType); - exprType = swizzledVec.type; + resolvedExprType = swizzledVec.type; } else throw AstError{ "unexpected type (only struct, vectors and matrices can be indexed)" }; //< TODO: Add support for arrays } - node.cachedExpressionType = std::move(exprType); + node.cachedExpressionType = std::move(resolvedExprType); } + + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(AssignExpression& node) + auto SanitizeVisitor::Validate(AssignExpression& node) -> ValidationResult { + MandatoryExpr(node.left); + MandatoryExpr(node.right); + + const ExpressionType* leftExprType = GetExpressionType(MandatoryExpr(node.left)); + if (!leftExprType) + return ValidationResult::Unresolved; + + const ExpressionType* rightExprType = GetExpressionType(MandatoryExpr(node.right)); + if (!rightExprType) + return ValidationResult::Unresolved; + if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue) throw AstError{ "Assignation is only possible with a l-value" }; @@ -2649,7 +3019,9 @@ namespace Nz::ShaderAst switch (node.op) { case AssignType::Simple: - TypeMustMatch(node.left, node.right); + if (TypeMustMatch(node.left, node.right) == ValidationResult::Unresolved) + return ValidationResult::Unresolved; + break; case AssignType::CompoundAdd: binaryType = BinaryType::Add; break; @@ -2662,8 +3034,8 @@ namespace Nz::ShaderAst if (binaryType) { - ExpressionType expressionType = ValidateBinaryOp(*binaryType, node.left, node.right); - TypeMustMatch(GetExpressionType(*node.left), expressionType); + ExpressionType expressionType = ValidateBinaryOp(*binaryType, ResolveAlias(*leftExprType), ResolveAlias(*rightExprType)); + TypeMustMatch(*leftExprType, expressionType); if (m_context->options.removeCompoundAssignments) { @@ -2673,15 +3045,28 @@ namespace Nz::ShaderAst } } - node.cachedExpressionType = GetExpressionType(*node.left); + node.cachedExpressionType = *leftExprType; + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(BinaryExpression& node) + auto SanitizeVisitor::Validate(BinaryExpression& node) -> ValidationResult { - node.cachedExpressionType = ValidateBinaryOp(node.op, node.left, node.right); + MandatoryExpr(node.left); + MandatoryExpr(node.right); + + const ExpressionType* leftExprType = GetExpressionType(MandatoryExpr(node.left)); + if (!leftExprType) + return ValidationResult::Unresolved; + + const ExpressionType* rightExprType = GetExpressionType(MandatoryExpr(node.right)); + if (!rightExprType) + return ValidationResult::Unresolved; + + node.cachedExpressionType = ValidateBinaryOp(node.op, ResolveAlias(*leftExprType), ResolveAlias(*rightExprType)); + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(CallFunctionExpression& node) + auto SanitizeVisitor::Validate(CallFunctionExpression& node) -> ValidationResult { std::size_t targetFuncIndex; if (node.targetFunction->GetType() == NodeType::FunctionExpression) @@ -2708,7 +3093,11 @@ namespace Nz::ShaderAst for (std::size_t i = 0; i < node.parameters.size(); ++i) { - if (GetExpressionType(*node.parameters[i]) != referenceDeclaration->parameters[i].type.GetResultingValue()) + const ExpressionType* parameterType = GetExpressionType(*node.parameters[i]); + if (!parameterType) + return ValidationResult::Unresolved; + + if (ResolveAlias(*parameterType) != ResolveAlias(referenceDeclaration->parameters[i].type.GetResultingValue())) throw AstError{ "function " + referenceDeclaration->name + " parameter " + std::to_string(i) + " type mismatch" }; } @@ -2716,12 +3105,16 @@ namespace Nz::ShaderAst throw AstError{ "function " + referenceDeclaration->name + " expected " + std::to_string(referenceDeclaration->parameters.size()) + " parameters, got " + std::to_string(node.parameters.size()) }; node.cachedExpressionType = referenceDeclaration->returnType.GetResultingValue(); + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(CastExpression& node) + auto SanitizeVisitor::Validate(CastExpression& node) -> ValidationResult { - ExpressionType resolvedType = ResolveType(node.targetType); - const ExpressionType& targetType = ResolveAlias(resolvedType); + std::optional targetTypeOpt = ResolveTypeExpr(node.targetType); + if (!targetTypeOpt) + return ValidationResult::Unresolved; + + const ExpressionType& targetType = ResolveAlias(*targetTypeOpt); const auto& firstExprPtr = node.expressions.front(); if (!firstExprPtr) @@ -2731,8 +3124,11 @@ namespace Nz::ShaderAst { const MatrixType& targetMatrixType = std::get(targetType); - const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr); - if (IsMatrixType(firstExprType)) + const ExpressionType* firstExprType = GetExpressionType(*firstExprPtr); + if (!firstExprType) + return ValidationResult::Unresolved; + + if (IsMatrixType(ResolveAlias(*firstExprType))) { if (node.expressions[1]) throw AstError{ "too many expressions" }; @@ -2748,11 +3144,15 @@ namespace Nz::ShaderAst if (!exprPtr) throw AstError{ "component count doesn't match required component count" }; - const ExpressionType& exprType = GetExpressionType(*exprPtr); - if (!IsVectorType(exprType)) + const ExpressionType* exprType = GetExpressionType(*exprPtr); + if (!exprType) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedExprType = ResolveAlias(*exprType); + if (!IsVectorType(resolvedExprType)) throw AstError{ "expected vector type" }; - const VectorType& vecType = std::get(exprType); + const VectorType& vecType = std::get(resolvedExprType); if (vecType.componentCount != targetMatrixType.rowCount) throw AstError{ "vector component count must match target matrix row count" }; } @@ -2779,22 +3179,28 @@ namespace Nz::ShaderAst if (!exprPtr) break; - const ExpressionType& exprType = GetExpressionType(*exprPtr); - if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) + const ExpressionType* exprType = GetExpressionType(*exprPtr); + if (!exprType) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedExprType = ResolveAlias(*exprType); + if (!IsPrimitiveType(resolvedExprType) && !IsVectorType(resolvedExprType)) throw AstError{ "incompatible type" }; - componentCount += GetComponentCount(exprType); + componentCount += GetComponentCount(resolvedExprType); } if (componentCount != requiredComponents) throw AstError{ "component count doesn't match required component count" }; } - node.cachedExpressionType = resolvedType; - node.targetType = std::move(resolvedType); + node.cachedExpressionType = targetType; + node.targetType = std::move(targetType); + + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(DeclareVariableStatement& node) + auto SanitizeVisitor::Validate(DeclareVariableStatement& node) -> ValidationResult { ExpressionType resolvedType; if (!node.varType.HasValue()) @@ -2802,13 +3208,36 @@ namespace Nz::ShaderAst if (!node.initialExpression) throw AstError{ "variable must either have a type or an initial value" }; - resolvedType = GetExpressionType(*node.initialExpression); + const ExpressionType* initialExprType = GetExpressionType(*node.initialExpression); + if (!initialExprType) + { + RegisterUnresolved(node.varName); + return ValidationResult::Unresolved; + } + + resolvedType = *initialExprType; } else { - resolvedType = ResolveType(node.varType); + std::optional varType = ResolveTypeExpr(node.varType); + if (!varType) + { + RegisterUnresolved(node.varName); + return ValidationResult::Unresolved; + } + + resolvedType = std::move(varType).value(); if (node.initialExpression) - TypeMustMatch(resolvedType, GetExpressionType(*node.initialExpression)); + { + const ExpressionType* initialExprType = GetExpressionType(*node.initialExpression); + if (!initialExprType) + { + RegisterUnresolved(node.varName); + return ValidationResult::Unresolved; + } + + TypeMustMatch(resolvedType, *initialExprType); + } } node.varIndex = RegisterVariable(node.varName, resolvedType, node.varIndex); @@ -2841,9 +3270,10 @@ namespace Nz::ShaderAst } SanitizeIdentifier(node.varName); + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(IntrinsicExpression& node) + auto SanitizeVisitor::Validate(IntrinsicExpression& node) -> ValidationResult { // Parameter validation switch (node.intrinsic) @@ -2861,11 +3291,17 @@ namespace Nz::ShaderAst for (auto& param : node.parameters) MandatoryExpr(param); - const ExpressionType& type = GetExpressionType(*node.parameters.front()); + const ExpressionType* firstParameterType = GetExpressionType(*node.parameters.front()); + if (!firstParameterType) + return ValidationResult::Unresolved; for (std::size_t i = 1; i < node.parameters.size(); ++i) { - if (type != GetExpressionType(*node.parameters[i])) + const ExpressionType* parameterType = GetExpressionType(*node.parameters[i]); + if (!parameterType) + return ValidationResult::Unresolved; + + if (ResolveAlias(*firstParameterType) != ResolveAlias(*parameterType)) throw AstError{ "All type must match" }; } @@ -2887,8 +3323,12 @@ namespace Nz::ShaderAst if (node.parameters.size() != 1) throw AstError{ "Expected only one parameters" }; - const ExpressionType& type = GetExpressionType(MandatoryExpr(node.parameters.front())); - if (!IsVectorType(type)) + const ExpressionType* type = GetExpressionType(MandatoryExpr(node.parameters.front())); + if (!type) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedType = ResolveAlias(*type); + if (!IsVectorType(resolvedType)) throw AstError{ "Expected a vector" }; break; @@ -2902,10 +3342,18 @@ namespace Nz::ShaderAst for (auto& param : node.parameters) MandatoryExpr(param); - if (!IsSamplerType(GetExpressionType(*node.parameters[0]))) + const ExpressionType* firstParameterType = GetExpressionType(*node.parameters[0]); + if (!firstParameterType) + return ValidationResult::Unresolved; + + if (!IsSamplerType(*firstParameterType)) throw AstError{ "First parameter must be a sampler" }; - if (!IsVectorType(GetExpressionType(*node.parameters[1]))) + const ExpressionType* secondParameterType = GetExpressionType(*node.parameters[1]); + if (!secondParameterType) + return ValidationResult::Unresolved; + + if (!IsVectorType(*secondParameterType)) throw AstError{ "Second parameter must be a vector" }; break; @@ -2917,94 +3365,125 @@ namespace Nz::ShaderAst { case IntrinsicType::CrossProduct: { - const ExpressionType& type = GetExpressionType(*node.parameters.front()); - if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) + const ExpressionType* type = GetExpressionType(*node.parameters.front()); + if (!type) + return ValidationResult::Unresolved; + + if (ResolveAlias(*type) != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) throw AstError{ "CrossProduct only works with vec3[f32] expressions" }; - node.cachedExpressionType = type; + node.cachedExpressionType = *type; break; } case IntrinsicType::DotProduct: case IntrinsicType::Length: { - const ExpressionType& type = GetExpressionType(*node.parameters.front()); - if (!IsVectorType(type)) + const ExpressionType* type = GetExpressionType(*node.parameters.front()); + if (!type) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedType = ResolveAlias(*type); + if (!IsVectorType(resolvedType)) throw AstError{ "DotProduct expects vector types" }; //< FIXME - node.cachedExpressionType = std::get(type).type; + node.cachedExpressionType = std::get(resolvedType).type; break; } case IntrinsicType::Normalize: case IntrinsicType::Reflect: { - const ExpressionType& type = GetExpressionType(*node.parameters.front()); - if (!IsVectorType(type)) + const ExpressionType* type = GetExpressionType(*node.parameters.front()); + if (!type) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedType = ResolveAlias(*type); + if (!IsVectorType(resolvedType)) throw AstError{ "DotProduct expects vector types" }; //< FIXME - node.cachedExpressionType = type; + node.cachedExpressionType = *type; break; } case IntrinsicType::Max: case IntrinsicType::Min: { - const ExpressionType& type = GetExpressionType(*node.parameters.front()); - if (!IsPrimitiveType(type) && !IsVectorType(type)) + const ExpressionType* type = GetExpressionType(*node.parameters.front()); + if (!type) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedType = ResolveAlias(*type); + if (!IsPrimitiveType(resolvedType) && !IsVectorType(resolvedType)) throw AstError{ "max and min only work with primitive and vector types" }; - if ((IsPrimitiveType(type) && std::get(type) == PrimitiveType::Boolean) || - (IsVectorType(type) && std::get(type).type == PrimitiveType::Boolean)) + if ((IsPrimitiveType(resolvedType) && std::get(resolvedType) == PrimitiveType::Boolean) || + (IsVectorType(resolvedType) && std::get(resolvedType).type == PrimitiveType::Boolean)) throw AstError{ "max and min do not work with booleans" }; - node.cachedExpressionType = type; + node.cachedExpressionType = *type; break; } case IntrinsicType::Exp: case IntrinsicType::Pow: { - const ExpressionType& type = GetExpressionType(*node.parameters.front()); - if (!IsPrimitiveType(type) && !IsVectorType(type)) + const ExpressionType* type = GetExpressionType(*node.parameters.front()); + if (!type) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedType = ResolveAlias(*type); + if (!IsPrimitiveType(resolvedType) && !IsVectorType(resolvedType)) throw AstError{ "pow only works with primitive and vector types" }; - if ((IsPrimitiveType(type) && std::get(type) != PrimitiveType::Float32) || - (IsVectorType(type) && std::get(type).type != PrimitiveType::Float32)) + if ((IsPrimitiveType(resolvedType) && std::get(resolvedType) != PrimitiveType::Float32) || + (IsVectorType(resolvedType) && std::get(resolvedType).type != PrimitiveType::Float32)) throw AstError{ "pow only works with floating-point primitive or vectors" }; - node.cachedExpressionType = type; + node.cachedExpressionType = *type; break; } case IntrinsicType::SampleTexture: { - node.cachedExpressionType = VectorType{ 4, std::get(GetExpressionType(*node.parameters.front())).sampledType }; + const ExpressionType* type = GetExpressionType(*node.parameters.front()); + if (!type) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedType = ResolveAlias(*type); + node.cachedExpressionType = VectorType{ 4, std::get(resolvedType).sampledType }; break; } } + + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(SwizzleExpression& node) + auto SanitizeVisitor::Validate(SwizzleExpression& node) -> ValidationResult { MandatoryExpr(node.expression); - const ExpressionType& exprType = GetExpressionType(*node.expression); - if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) + const ExpressionType* exprType = GetExpressionType(*node.expression); + if (!exprType) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedExprType = ResolveAlias(*exprType); + + if (!IsPrimitiveType(resolvedExprType) && !IsVectorType(resolvedExprType)) throw AstError{ "Cannot swizzle this type" }; PrimitiveType baseType; std::size_t componentCount; - if (IsPrimitiveType(exprType)) + if (IsPrimitiveType(resolvedExprType)) { if (m_context->options.removeScalarSwizzling) throw AstError{ "internal error" }; //< scalar swizzling should have been removed by then - baseType = std::get(exprType); + baseType = std::get(resolvedExprType); componentCount = 1; } else { - const VectorType& vecType = std::get(exprType); + const VectorType& vecType = std::get(resolvedExprType); baseType = vecType.type; componentCount = vecType.componentCount; } @@ -3027,17 +3506,23 @@ namespace Nz::ShaderAst } else node.cachedExpressionType = baseType; + + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(UnaryExpression& node) + auto SanitizeVisitor::Validate(UnaryExpression& node) -> ValidationResult { - const ExpressionType& exprType = GetExpressionType(MandatoryExpr(node.expression)); + const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expression)); + if (!exprType) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedExprType = ResolveAlias(*exprType); switch (node.op) { case UnaryType::LogicalNot: { - if (exprType != ExpressionType(PrimitiveType::Boolean)) + if (resolvedExprType != ExpressionType(PrimitiveType::Boolean)) throw AstError{ "logical not is only supported on booleans" }; break; @@ -3047,10 +3532,10 @@ namespace Nz::ShaderAst case UnaryType::Plus: { PrimitiveType basicType; - if (IsPrimitiveType(exprType)) - basicType = std::get(exprType); - else if (IsVectorType(exprType)) - basicType = std::get(exprType).type; + if (IsPrimitiveType(resolvedExprType)) + basicType = std::get(resolvedExprType); + else if (IsVectorType(resolvedExprType)) + basicType = std::get(resolvedExprType).type; else throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" }; @@ -3061,19 +3546,18 @@ namespace Nz::ShaderAst } } - node.cachedExpressionType = exprType; + node.cachedExpressionType = *exprType; + return ValidationResult::Validated; } - void SanitizeVisitor::Validate(VariableValueExpression& node) + auto SanitizeVisitor::Validate(VariableValueExpression& node) -> ValidationResult { node.cachedExpressionType = m_context->variableTypes.Retrieve(node.variableId); + return ValidationResult::Validated; } - ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr) + ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType) { - const ExpressionType& leftExprType = GetExpressionType(MandatoryExpr(leftExpr)); - const ExpressionType& rightExprType = GetExpressionType(MandatoryExpr(rightExpr)); - if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType)) throw AstError{ "left expression type does not support binary operation" }; @@ -3096,13 +3580,13 @@ namespace Nz::ShaderAst case BinaryType::CompEq: case BinaryType::CompNe: { - TypeMustMatch(leftExpr, rightExpr); + TypeMustMatch(leftExprType, rightExprType); return PrimitiveType::Boolean; } case BinaryType::Add: case BinaryType::Subtract: - TypeMustMatch(leftExpr, rightExpr); + TypeMustMatch(leftExprType, rightExprType); return leftExprType; case BinaryType::Multiply: @@ -3149,7 +3633,7 @@ namespace Nz::ShaderAst if (leftType != PrimitiveType::Boolean) throw AstError{ "logical and/or are only supported on booleans" }; - TypeMustMatch(leftExpr, rightExpr); + TypeMustMatch(leftExprType, rightExprType); return PrimitiveType::Boolean; } } @@ -3165,12 +3649,12 @@ namespace Nz::ShaderAst case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: - TypeMustMatch(leftExpr, rightExpr); + TypeMustMatch(leftExprType, rightExprType); return PrimitiveType::Boolean; case BinaryType::Add: case BinaryType::Subtract: - TypeMustMatch(leftExpr, rightExpr); + TypeMustMatch(leftExprType, rightExprType); return leftExprType; case BinaryType::Multiply: @@ -3216,12 +3700,12 @@ namespace Nz::ShaderAst case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: - TypeMustMatch(leftExpr, rightExpr); + TypeMustMatch(leftExprType, rightExprType); return PrimitiveType::Boolean; case BinaryType::Add: case BinaryType::Subtract: - TypeMustMatch(leftExpr, rightExpr); + TypeMustMatch(leftExprType, rightExprType); return leftExprType; case BinaryType::Multiply: diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 03ed9d909..890ab1519 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -36,7 +36,7 @@ namespace Nz AstRecursiveVisitor::Visit(node); assert(currentFunction); - currentFunction->calledFunctions.UnboundedSet(std::get(GetExpressionType(*node.targetFunction)).funcIndex); + currentFunction->calledFunctions.UnboundedSet(std::get(*GetExpressionType(*node.targetFunction)).funcIndex); } void Visit(ShaderAst::ConditionalExpression& /*node*/) override @@ -307,7 +307,7 @@ namespace Nz Append(type.GetResultingValue()); } - void GlslWriter::Append(const ShaderAst::FunctionType& functionType) + void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/) { throw std::runtime_error("unexpected FunctionType"); } @@ -829,8 +829,9 @@ namespace Nz { Visit(node.expr, true); - const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); - assert(IsStructType(exprType)); + const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr); + assert(exprType); + assert(IsStructType(*exprType)); for (const std::string& identifier : node.identifiers) Append(".", identifier); @@ -840,8 +841,9 @@ namespace Nz { Visit(node.expr, true); - const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); - assert(!IsStructType(exprType)); + const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr); + assert(exprType); + assert(!IsStructType(*exprType)); // Array access assert(node.indices.size() == 1); @@ -1326,9 +1328,10 @@ namespace Nz { assert(node.returnExpr); - const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr); - assert(IsStructType(returnType)); - std::size_t structIndex = std::get(returnType).structIndex; + const ShaderAst::ExpressionType* returnType = GetExpressionType(*node.returnExpr); + assert(returnType); + assert(IsStructType(*returnType)); + std::size_t structIndex = std::get(*returnType).structIndex; const auto& structData = Retrieve(m_currentState->structs, structIndex); std::string outputStructVarName; diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 2360c8c99..b68fd1f37 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -182,7 +182,7 @@ namespace Nz type.GetExpression()->Visit(*this); } - void LangWriter::Append(const ShaderAst::FunctionType& functionType) + void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/) { throw std::runtime_error("unexpected function type"); } diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 1f549d5a1..5e1ae37b6 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -263,7 +263,7 @@ namespace Nz::ShaderLang throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" }; auto& constantValue = SafeCast(*expr); - if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String }) + if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String }) throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" }; const std::string& versionStr = std::get(constantValue.value); @@ -302,7 +302,7 @@ namespace Nz::ShaderLang throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" }; auto& constantValue = SafeCast(*expr); - if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String }) + if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String }) throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" }; const std::string& uuidStr = std::get(constantValue.value); diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 74ac68961..69b549ff1 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -67,9 +67,9 @@ namespace Nz throw std::runtime_error("unexpected type"); }; - const ShaderAst::ExpressionType& resultType = GetExpressionType(node); - const ShaderAst::ExpressionType& leftType = GetExpressionType(*node.left); - const ShaderAst::ExpressionType& rightType = GetExpressionType(*node.right); + const ShaderAst::ExpressionType& resultType = *GetExpressionType(node); + const ShaderAst::ExpressionType& leftType = *GetExpressionType(*node.left); + const ShaderAst::ExpressionType& rightType = *GetExpressionType(*node.right); ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType); //ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType); @@ -405,7 +405,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node) { - std::size_t functionIndex = std::get(GetExpressionType(*node.targetFunction)).funcIndex; + std::size_t functionIndex = std::get(*GetExpressionType(*node.targetFunction)).funcIndex; UInt32 funcId = 0; for (const auto& [funcIndex, func] : m_funcData) @@ -434,7 +434,7 @@ namespace Nz UInt32 resultId = AllocateResultId(); m_currentBlock->AppendVariadic(SpirvOp::OpFunctionCall, [&](auto&& appender) { - appender(m_writer.GetTypeId(ShaderAst::GetExpressionType(node))); + appender(m_writer.GetTypeId(*ShaderAst::GetExpressionType(node))); appender(resultId); appender(funcId); @@ -718,9 +718,11 @@ namespace Nz { UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); - const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); - assert(IsVectorType(parameterType)); - UInt32 typeId = m_writer.GetTypeId(parameterType); + const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]); + assert(parameterType); + assert(IsVectorType(*parameterType)); + + UInt32 typeId = m_writer.GetTypeId(*parameterType); UInt32 firstParam = EvaluateExpression(node.parameters[0]); UInt32 secondParam = EvaluateExpression(node.parameters[1]); @@ -733,10 +735,11 @@ namespace Nz case ShaderAst::IntrinsicType::DotProduct: { - const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); - assert(IsVectorType(vecExprType)); + const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]); + assert(vecExprType); + assert(IsVectorType(*vecExprType)); - const ShaderAst::VectorType& vecType = std::get(vecExprType); + const ShaderAst::VectorType& vecType = std::get(*vecExprType); UInt32 typeId = m_writer.GetTypeId(vecType.type); @@ -754,9 +757,10 @@ namespace Nz { UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); - const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); - assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType)); - UInt32 typeId = m_writer.GetTypeId(parameterType); + const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]); + assert(parameterType); + assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType)); + UInt32 typeId = m_writer.GetTypeId(*parameterType); UInt32 param = EvaluateExpression(node.parameters[0]); UInt32 resultId = m_writer.AllocateResultId(); @@ -770,10 +774,11 @@ namespace Nz { UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); - const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); - assert(IsVectorType(vecExprType)); + const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]); + assert(vecExprType); + assert(IsVectorType(*vecExprType)); - const ShaderAst::VectorType& vecType = std::get(vecExprType); + const ShaderAst::VectorType& vecType = std::get(*vecExprType); UInt32 typeId = m_writer.GetTypeId(vecType.type); UInt32 vec = EvaluateExpression(node.parameters[0]); @@ -790,15 +795,16 @@ namespace Nz { UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); - const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); - assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType)); - UInt32 typeId = m_writer.GetTypeId(parameterType); + const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]); + assert(parameterType); + assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType)); + UInt32 typeId = m_writer.GetTypeId(*parameterType); ShaderAst::PrimitiveType basicType; - if (IsPrimitiveType(parameterType)) - basicType = std::get(parameterType); - else if (IsVectorType(parameterType)) - basicType = std::get(parameterType).type; + if (IsPrimitiveType(*parameterType)) + basicType = std::get(*parameterType); + else if (IsVectorType(*parameterType)) + basicType = std::get(*parameterType).type; else throw std::runtime_error("unexpected expression type"); @@ -837,10 +843,11 @@ namespace Nz { UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); - const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); - assert(IsVectorType(vecExprType)); + const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]); + assert(vecExprType); + assert(IsVectorType(*vecExprType)); - const ShaderAst::VectorType& vecType = std::get(vecExprType); + const ShaderAst::VectorType& vecType = std::get(*vecExprType); UInt32 typeId = m_writer.GetTypeId(vecType); UInt32 vec = EvaluateExpression(node.parameters[0]); @@ -856,9 +863,10 @@ namespace Nz { UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); - const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); - assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType)); - UInt32 typeId = m_writer.GetTypeId(parameterType); + const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]); + assert(parameterType); + assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType)); + UInt32 typeId = m_writer.GetTypeId(*parameterType); UInt32 firstParam = EvaluateExpression(node.parameters[0]); UInt32 secondParam = EvaluateExpression(node.parameters[1]); @@ -873,9 +881,10 @@ namespace Nz { UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); - const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); - assert(IsVectorType(parameterType)); - UInt32 typeId = m_writer.GetTypeId(parameterType); + const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]); + assert(parameterType); + assert(IsVectorType(*parameterType)); + UInt32 typeId = m_writer.GetTypeId(*parameterType); UInt32 firstParam = EvaluateExpression(node.parameters[0]); UInt32 secondParam = EvaluateExpression(node.parameters[1]); @@ -951,20 +960,22 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) { - const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression); + const ShaderAst::ExpressionType* swizzledExpressionType = GetExpressionType(*node.expression); + assert(swizzledExpressionType); UInt32 exprResultId = EvaluateExpression(node.expression); - const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node); + const ShaderAst::ExpressionType* targetExprType = GetExpressionType(node); + assert(targetExprType); if (node.componentCount > 1) { - assert(IsVectorType(targetExprType)); + assert(IsVectorType(*targetExprType)); - const ShaderAst::VectorType& targetType = std::get(targetExprType); + const ShaderAst::VectorType& targetType = std::get(*targetExprType); UInt32 resultId = m_writer.AllocateResultId(); - if (IsVectorType(swizzledExpressionType)) + if (IsVectorType(*swizzledExpressionType)) { // Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender) @@ -980,7 +991,7 @@ namespace Nz } else { - assert(IsPrimitiveType(swizzledExpressionType)); + assert(IsPrimitiveType(*swizzledExpressionType)); // Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender) @@ -995,10 +1006,10 @@ namespace Nz PushResultId(resultId); } - else if (IsVectorType(swizzledExpressionType)) + else if (IsVectorType(*swizzledExpressionType)) { - assert(IsPrimitiveType(targetExprType)); - ShaderAst::PrimitiveType targetType = std::get(targetExprType); + assert(IsPrimitiveType(*targetExprType)); + ShaderAst::PrimitiveType targetType = std::get(*targetExprType); // Extract a single component from the vector assert(node.componentCount == 1); @@ -1011,8 +1022,8 @@ namespace Nz else { // Swizzling a primitive to itself (a.x for example), don't do anything - assert(IsPrimitiveType(swizzledExpressionType)); - assert(IsPrimitiveType(targetExprType)); + assert(IsPrimitiveType(*swizzledExpressionType)); + assert(IsPrimitiveType(*targetExprType)); assert(node.componentCount == 1); assert(node.components[0] == 0); @@ -1022,8 +1033,11 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node) { - const ShaderAst::ExpressionType& resultType = GetExpressionType(node); - const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expression); + const ShaderAst::ExpressionType* resultType = GetExpressionType(node); + assert(resultType); + + const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expression); + assert(exprType); UInt32 operand = EvaluateExpression(node.expression); @@ -1033,11 +1047,11 @@ namespace Nz { case ShaderAst::UnaryType::LogicalNot: { - assert(IsPrimitiveType(exprType)); - assert(std::get(resultType) == ShaderAst::PrimitiveType::Boolean); + assert(IsPrimitiveType(*exprType)); + assert(std::get(*resultType) == ShaderAst::PrimitiveType::Boolean); UInt32 resultId = m_writer.AllocateResultId(); - m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(resultType), resultId, operand); + m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(*resultType), resultId, operand); return resultId; } @@ -1045,10 +1059,10 @@ namespace Nz case ShaderAst::UnaryType::Minus: { ShaderAst::PrimitiveType basicType; - if (IsPrimitiveType(exprType)) - basicType = std::get(exprType); - else if (IsVectorType(exprType)) - basicType = std::get(exprType).type; + if (IsPrimitiveType(*exprType)) + basicType = std::get(*exprType); + else if (IsVectorType(*exprType)) + basicType = std::get(*exprType).type; else throw std::runtime_error("unexpected expression type"); @@ -1057,12 +1071,12 @@ namespace Nz switch (basicType) { case ShaderAst::PrimitiveType::Float32: - m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(resultType), resultId, operand); + m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(*resultType), resultId, operand); return resultId; case ShaderAst::PrimitiveType::Int32: case ShaderAst::PrimitiveType::UInt32: - m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(resultType), resultId, operand); + m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(*resultType), resultId, operand); return resultId; default: diff --git a/src/Nazara/Shader/SpirvExpressionLoad.cpp b/src/Nazara/Shader/SpirvExpressionLoad.cpp index 91d87f693..690b535b0 100644 --- a/src/Nazara/Shader/SpirvExpressionLoad.cpp +++ b/src/Nazara/Shader/SpirvExpressionLoad.cpp @@ -76,9 +76,10 @@ namespace Nz { node.expr->Visit(*this); - const ShaderAst::ExpressionType& exprType = GetExpressionType(node); + const ShaderAst::ExpressionType* exprType = GetExpressionType(node); + assert(exprType); - UInt32 typeId = m_writer.GetTypeId(exprType); + UInt32 typeId = m_writer.GetTypeId(*exprType); assert(node.indices.size() == 1); UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front()); @@ -88,7 +89,7 @@ namespace Nz [&](const Pointer& pointer) { PointerChainAccess pointerChainAccess; - pointerChainAccess.exprType = &exprType; + pointerChainAccess.exprType = exprType; pointerChainAccess.indices = { indexId }; pointerChainAccess.pointedTypeId = pointer.pointedTypeId; pointerChainAccess.pointerId = pointer.pointerId; @@ -98,7 +99,7 @@ namespace Nz }, [&](PointerChainAccess& pointerChainAccess) { - pointerChainAccess.exprType = &exprType; + pointerChainAccess.exprType = exprType; pointerChainAccess.indices.push_back(indexId); }, [&](const Value& value) diff --git a/src/Nazara/Shader/SpirvExpressionStore.cpp b/src/Nazara/Shader/SpirvExpressionStore.cpp index fca711a33..b9c654e38 100644 --- a/src/Nazara/Shader/SpirvExpressionStore.cpp +++ b/src/Nazara/Shader/SpirvExpressionStore.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -61,11 +62,12 @@ namespace Nz } else { - const ShaderAst::ExpressionType& exprType = GetExpressionType(*node); - assert(swizzledPointer.componentCount == 1); - UInt32 pointerType = m_writer.RegisterPointerType(exprType, swizzledPointer.storage); //< FIXME + const ShaderAst::ExpressionType* exprType = GetExpressionType(*node); + assert(exprType); + + UInt32 pointerType = m_writer.RegisterPointerType(*exprType, swizzledPointer.storage); //< FIXME // Access chain UInt32 indexId = m_writer.GetConstantId(SafeCast(swizzledPointer.swizzleIndices[0])); @@ -86,14 +88,15 @@ namespace Nz { node.expr->Visit(*this); - const ShaderAst::ExpressionType& exprType = GetExpressionType(node); + const ShaderAst::ExpressionType* exprType = GetExpressionType(node); + assert(exprType); std::visit(Overloaded { [&](const Pointer& pointer) { UInt32 resultId = m_visitor.AllocateResultId(); - UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME + UInt32 pointerType = m_writer.RegisterPointerType(*exprType, pointer.storage); //< FIXME assert(node.indices.size() == 1); UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front()); @@ -117,13 +120,14 @@ namespace Nz { [&](const Pointer& pointer) { - const auto& expressionType = GetExpressionType(*node.expression); - assert(IsVectorType(expressionType)); + const ShaderAst::ExpressionType* expressionType = GetExpressionType(*node.expression); + assert(expressionType); + assert(IsVectorType(*expressionType)); SwizzledPointer swizzledPointer; swizzledPointer.pointerId = pointer.pointerId; swizzledPointer.storage = pointer.storage; - swizzledPointer.swizzledType = std::get(expressionType); + swizzledPointer.swizzledType = std::get(*expressionType); swizzledPointer.componentCount = node.componentCount; swizzledPointer.swizzleIndices = node.components; diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 3cf985e43..d6f555e02 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -98,7 +98,7 @@ namespace Nz for (const auto& parameter : node.parameters) { auto& var = func.variables.emplace_back(); - var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function)); + var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(*GetExpressionType(*parameter), SpirvStorageClass::Function)); } }