From ac9e7207ac5a115f59bd1dda262bd01cde17fab1 Mon Sep 17 00:00:00 2001 From: SirLynix Date: Wed, 30 Mar 2022 20:21:36 +0200 Subject: [PATCH] Shader: Add compiler and AST errors (WIP) I'm so afraid to lose all this work --- include/Nazara/Shader/Ast/AstReflect.hpp | 13 +- include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 48 +- include/Nazara/Shader/ShaderLangErrorList.hpp | 54 +- .../Ast/AstConstantPropagationVisitor.cpp | 10 +- src/Nazara/Shader/Ast/AstReflect.cpp | 18 +- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 1120 +++++++++-------- 6 files changed, 707 insertions(+), 556 deletions(-) diff --git a/include/Nazara/Shader/Ast/AstReflect.hpp b/include/Nazara/Shader/Ast/AstReflect.hpp index 0e4de1fcb..27755c868 100644 --- a/include/Nazara/Shader/Ast/AstReflect.hpp +++ b/include/Nazara/Shader/Ast/AstReflect.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -41,12 +42,12 @@ namespace Nz::ShaderAst std::function onStructDeclaration; std::function onVariableDeclaration; - std::function onAliasIndex; - std::function onConstIndex; - std::function onFunctionIndex; - std::function onOptionIndex; - std::function onStructIndex; - std::function onVariableIndex; + std::function onAliasIndex; + std::function onConstIndex; + std::function onFunctionIndex; + std::function onOptionIndex; + std::function onStructIndex; + std::function onVariableIndex; }; private: diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 4862547b6..77e2d8a6e 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -39,8 +39,6 @@ namespace Nz::ShaderAst SanitizeVisitor& operator=(const SanitizeVisitor&) = delete; SanitizeVisitor& operator=(SanitizeVisitor&&) = delete; - static UInt32 ToSwizzleIndex(char c); - struct Options { std::shared_ptr moduleResolver; @@ -113,13 +111,10 @@ namespace Nz::ShaderAst const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const; template const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const; - ExpressionPtr HandleIdentifier(const IdentifierData* identifierData); - const ExpressionType* GetExpressionType(Expression& expr) const; const ExpressionType& GetExpressionTypeSecure(Expression& expr) const; - Expression& MandatoryExpr(const ExpressionPtr& node) const; - Statement& MandatoryStatement(const StatementPtr& node) const; + ExpressionPtr HandleIdentifier(const IdentifierData* identifierData, const ShaderLang::SourceLocation& sourceLocation); void PushScope(); void PopScope(); @@ -136,34 +131,31 @@ namespace Nz::ShaderAst void RegisterBuiltin(); - std::size_t RegisterAlias(std::string name, std::optional aliasData, std::optional index = {}); - std::size_t RegisterConstant(std::string name, std::optional value, std::optional index = {}); - std::size_t RegisterFunction(std::string name, std::optional funcData, std::optional index = {}); + std::size_t RegisterAlias(std::string name, std::optional aliasData, std::optional index, const ShaderLang::SourceLocation& sourceLocation); + std::size_t RegisterConstant(std::string name, std::optional value, std::optional index, const ShaderLang::SourceLocation& sourceLocation); + std::size_t RegisterFunction(std::string name, std::optional funcData, std::optional index, const ShaderLang::SourceLocation& sourceLocation); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex); - std::size_t RegisterStruct(std::string name, std::optional description, std::optional index = {}); - std::size_t RegisterType(std::string name, std::optional expressionType, std::optional index = {}); - std::size_t RegisterType(std::string name, std::optional partialType, std::optional index = {}); + std::size_t RegisterStruct(std::string name, std::optional description, std::optional index, const ShaderLang::SourceLocation& sourceLocation); + std::size_t RegisterType(std::string name, std::optional expressionType, std::optional index, const ShaderLang::SourceLocation& sourceLocation); + std::size_t RegisterType(std::string name, std::optional partialType, std::optional index, const ShaderLang::SourceLocation& sourceLocation); void RegisterUnresolved(std::string name); - std::size_t RegisterVariable(std::string name, std::optional type, std::optional index = {}); + std::size_t RegisterVariable(std::string name, std::optional type, std::optional index, const ShaderLang::SourceLocation& sourceLocation); - const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const; + const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier, const ShaderLang::SourceLocation& sourceLocation) const; void ResolveFunctions(); std::size_t ResolveStruct(const AliasType& aliasType); std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const IdentifierType& identifierType); std::size_t ResolveStruct(const StructType& structType); std::size_t ResolveStruct(const UniformType& uniformType); - ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false); - std::optional ResolveTypeExpr(const ExpressionValue& exprTypeValue, bool resolveAlias = false); + ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias, const ShaderLang::SourceLocation& sourceLocation); + std::optional ResolveTypeExpr(const ExpressionValue& exprTypeValue, bool resolveAlias, const ShaderLang::SourceLocation& sourceLocation); void SanitizeIdentifier(std::string& identifier); MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error); - ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const; - void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const; - - StatementPtr Unscope(StatementPtr node); + ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right, const ShaderLang::SourceLocation& sourceLocation); ValidationResult Validate(DeclareAliasStatement& node); ValidationResult Validate(WhileStatement& node); @@ -178,7 +170,21 @@ namespace Nz::ShaderAst ValidationResult Validate(SwizzleExpression& node); ValidationResult Validate(UnaryExpression& node); ValidationResult Validate(VariableValueExpression& node); - ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType); + ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType, const ShaderLang::SourceLocation& sourceLocation); + + template ValidationResult ValidateIntrinsicParamCount(IntrinsicExpression& node); + ValidationResult ValidateIntrinsicParamMatchingType(IntrinsicExpression& node); + template ValidationResult ValidateIntrinsicParameter(IntrinsicExpression& node, F&& func); + template ValidationResult ValidateIntrinsicParameterType(IntrinsicExpression& node, F&& func); + + static Expression& MandatoryExpr(const ExpressionPtr& node, const ShaderLang::SourceLocation& sourceLocation); + static Statement& MandatoryStatement(const StatementPtr& node, const ShaderLang::SourceLocation& sourceLocation); + + static void TypeMustMatch(const ExpressionType& left, const ExpressionType& right, const ShaderLang::SourceLocation& sourceLocation); + + static StatementPtr Unscope(StatementPtr node); + + static UInt32 ToSwizzleIndex(char c, const ShaderLang::SourceLocation& sourceLocation); enum class IdentifierCategory { diff --git a/include/Nazara/Shader/ShaderLangErrorList.hpp b/include/Nazara/Shader/ShaderLangErrorList.hpp index d53a2ea56..d8a91edd8 100644 --- a/include/Nazara/Shader/ShaderLangErrorList.hpp +++ b/include/Nazara/Shader/ShaderLangErrorList.hpp @@ -51,11 +51,63 @@ NAZARA_SHADERLANG_PARSER_ERROR(16, UnexpectedEndOfFile, "unexpected end of file" NAZARA_SHADERLANG_PARSER_ERROR(17, UnexpectedToken, "unexpected token {}", ShaderLang::TokenType) // Compiler errors -NAZARA_SHADERLANG_COMPILER_ERROR(1, InvalidSwizzle, "invalid swizzle {}", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, BinaryIncompatibleTypes, "incompatibles types ( and )") +NAZARA_SHADERLANG_COMPILER_ERROR(2, BinaryUnsupported, "{} type () does not support this binary operation", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, BranchOutsideOfFunction, "non-const branching statements can only exist inside a function") +NAZARA_SHADERLANG_COMPILER_ERROR(2, CastIncompatibleTypes, "incompatibles types ( and )") +NAZARA_SHADERLANG_COMPILER_ERROR(2, CastComponentMismatch, "component count doesn't match required component count") +NAZARA_SHADERLANG_COMPILER_ERROR(2, CircularImport, "circular import detected on {}", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, ConditionExpectedBool, "expected a boolean value") +NAZARA_SHADERLANG_COMPILER_ERROR(2, ConstMissingExpression, "const variables must have an expression") +NAZARA_SHADERLANG_COMPILER_ERROR(2, ConstantExpectedValue, "expected a value") +NAZARA_SHADERLANG_COMPILER_ERROR(2, ConstantExpressionRequired, "a constant expression is required in this context") +NAZARA_SHADERLANG_COMPILER_ERROR(2, DepthWriteAttribute, "only fragment entry-points can have the depth_write attribute") +NAZARA_SHADERLANG_COMPILER_ERROR(2, DiscardEarlyFragmentTests, "discard is not compatible with early fragment tests") +NAZARA_SHADERLANG_COMPILER_ERROR(2, EarlyFragmentTestsAttribute, "only functions with entry(frag) attribute can have the early_fragments_tests attribute") +NAZARA_SHADERLANG_COMPILER_ERROR(2, EntryFunctionParameter, "entry functions can either take one struct parameter or no parameter") +NAZARA_SHADERLANG_COMPILER_ERROR(2, EntryPointAlreadyDefined, "the same entry type has been defined multiple times") +NAZARA_SHADERLANG_COMPILER_ERROR(2, ExpectedFunction, "expected function expression") +NAZARA_SHADERLANG_COMPILER_ERROR(2, ExpectedIntrinsicFunction, "expected intrinsic function expression") +NAZARA_SHADERLANG_COMPILER_ERROR(2, ExtAlreadyDeclared, "external variable {} is already declared", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, ExtTypeNotAllowed, "external variable {} is of wrong type: only uniform and sampler are allowed in external blocks", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, ExtBindingAlreadyUsed, "binding (set={}, binding={}) is already in use", UInt32, UInt32) +NAZARA_SHADERLANG_COMPILER_ERROR(2, ExtMissingBindingIndex, "external variable requires a binding index") +NAZARA_SHADERLANG_COMPILER_ERROR(2, ForEachUnsupportedType, "for-each statements can only be called on array types, got ") +NAZARA_SHADERLANG_COMPILER_ERROR(2, FunctionCallOutsideOfFunction, "function calls must happen inside a function") +NAZARA_SHADERLANG_COMPILER_ERROR(2, FunctionDeclarationInsideFunction, "a function cannot be defined inside another function") +NAZARA_SHADERLANG_COMPILER_ERROR(2, IdentifierAlreadyUsed, "identifier {} is already used", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, IntrinsicExpectedParameterCount, "expected {} parameter(s)", unsigned int) +NAZARA_SHADERLANG_COMPILER_ERROR(2, IntrinsicExpectedFloat, "expected scalar or vector floating-points") +NAZARA_SHADERLANG_COMPILER_ERROR(2, IntrinsicExpectedType, "expected type for parameter #{}, got ", unsigned int) +NAZARA_SHADERLANG_COMPILER_ERROR(2, IntrinsicUnexpectedBoolean, "boolean parameters are not allowed") +NAZARA_SHADERLANG_COMPILER_ERROR(2, IntrinsicUnmatchingParameterType, "all types must match") +NAZARA_SHADERLANG_COMPILER_ERROR(2, InvalidScalarSwizzle, "invalid swizzle for scalar") +NAZARA_SHADERLANG_COMPILER_ERROR(2, InvalidSwizzle, "invalid swizzle {}", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, MissingOptionValue, "option {} requires a value (no default value set)", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, PartialTypeExpect, "expected a {} type at #{}", std::string, unsigned int) +NAZARA_SHADERLANG_COMPILER_ERROR(2, StructDeclarationInsideFunction, "structs must be declared outside of functions") +NAZARA_SHADERLANG_COMPILER_ERROR(2, VarDeclarationMissingTypeAndValue, "variable must either have a type or an initial value") +NAZARA_SHADERLANG_COMPILER_ERROR(2, VarDeclarationTypeUnmatching, "initial expression type () doesn't match specified type ()") +NAZARA_SHADERLANG_COMPILER_ERROR(2, UnexpectedAccessedType, "unexpected type (only struct and vectors can be indexed with identifiers)") +NAZARA_SHADERLANG_COMPILER_ERROR(2, UnaryUnsupported, "type () does not support this unary operation", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, UnmatchingTypes, "left expression type () doesn't match right expression type ()") +NAZARA_SHADERLANG_COMPILER_ERROR(2, UnknownField, "unknown field {}", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, UnknownMethod, "unknown method {}", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, UnknownIdentifier, "unknown identifier {}", std::string) +NAZARA_SHADERLANG_COMPILER_ERROR(2, WhileUnrollNotSupported, "unroll(always) is not yet supported on while") // AST errors NAZARA_SHADERLANG_AST_ERROR(1, AlreadyUsedIndex, "index {} is already used", std::size_t) +NAZARA_SHADERLANG_AST_ERROR(2, AlreadyUsedIndexPreregister, "cannot preregister used index {} as its already used", std::size_t) +NAZARA_SHADERLANG_AST_ERROR(2, EmptyIdentifier, "identifier cannot be empty") +NAZARA_SHADERLANG_AST_ERROR(2, Internal, "internal error: {}", std::string) +NAZARA_SHADERLANG_AST_ERROR(2, InvalidConstantIndex, "invalid constant index #{}", std::size_t) NAZARA_SHADERLANG_AST_ERROR(2, InvalidIndex, "invalid index {}", std::size_t) +NAZARA_SHADERLANG_AST_ERROR(2, MissingExpression, "a mandatory expression is missing") +NAZARA_SHADERLANG_AST_ERROR(2, MissingStatement, "a mandatory statement is missing") +NAZARA_SHADERLANG_AST_ERROR(2, NoIdentifier, "at least one identifier is required") +NAZARA_SHADERLANG_AST_ERROR(2, NoIndex, "at least one index is required") +NAZARA_SHADERLANG_AST_ERROR(2, UnexpectedIdentifier, "unexpected identifier of type {}", std::string) #undef NAZARA_SHADERLANG_ERROR #undef NAZARA_SHADERLANG_AST_ERROR diff --git a/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp b/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp index cd9b62c5b..b622f9480 100644 --- a/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstConstantPropagationVisitor.cpp @@ -811,6 +811,7 @@ namespace Nz::ShaderAst auto binary = ShaderBuilder::Binary(node.op, std::move(lhs), std::move(rhs)); binary->cachedExpressionType = node.cachedExpressionType; + binary->sourceLocation = node.sourceLocation; return binary; } @@ -931,6 +932,7 @@ namespace Nz::ShaderAst auto cast = ShaderBuilder::Cast(node.targetType.GetResultingValue(), std::move(expressions)); cast->cachedExpressionType = node.cachedExpressionType; + cast->sourceLocation = node.sourceLocation; return cast; } @@ -996,7 +998,10 @@ namespace Nz::ShaderAst if (!elseStatement) elseStatement = CloneStatement(node.elseStatement); - return ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)); + auto branchStatement = ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)); + branchStatement->sourceLocation = node.sourceLocation; + + return branchStatement; } ExpressionPtr AstConstantPropagationVisitor::Clone(ConditionalExpression& node) @@ -1031,6 +1036,7 @@ namespace Nz::ShaderAst auto constant = ShaderBuilder::Constant(*constantValue); constant->cachedExpressionType = GetConstantType(constant->value); + constant->sourceLocation = node.sourceLocation; return constant; } @@ -1082,6 +1088,7 @@ namespace Nz::ShaderAst auto swizzle = ShaderBuilder::Swizzle(std::move(expr), node.components, node.componentCount); swizzle->cachedExpressionType = node.cachedExpressionType; + swizzle->sourceLocation = node.sourceLocation; return swizzle; } @@ -1116,6 +1123,7 @@ namespace Nz::ShaderAst auto unary = ShaderBuilder::Unary(node.op, std::move(expr)); unary->cachedExpressionType = node.cachedExpressionType; + unary->sourceLocation = node.sourceLocation; return unary; } diff --git a/src/Nazara/Shader/Ast/AstReflect.cpp b/src/Nazara/Shader/Ast/AstReflect.cpp index dc11f9841..43c5684a1 100644 --- a/src/Nazara/Shader/Ast/AstReflect.cpp +++ b/src/Nazara/Shader/Ast/AstReflect.cpp @@ -21,7 +21,7 @@ namespace Nz::ShaderAst m_callbacks->onAliasDeclaration(node); if (m_callbacks->onAliasIndex && node.aliasIndex) - m_callbacks->onAliasIndex(node.name, *node.aliasIndex); + m_callbacks->onAliasIndex(node.name, *node.aliasIndex, node.sourceLocation); AstRecursiveVisitor::Visit(node); } @@ -33,7 +33,7 @@ namespace Nz::ShaderAst m_callbacks->onConstDeclaration(node); if (m_callbacks->onConstIndex && node.constIndex) - m_callbacks->onConstIndex(node.name, *node.constIndex); + m_callbacks->onConstIndex(node.name, *node.constIndex, node.sourceLocation); AstRecursiveVisitor::Visit(node); } @@ -49,7 +49,7 @@ namespace Nz::ShaderAst for (const auto& extVar : node.externalVars) { if (extVar.varIndex) - m_callbacks->onVariableIndex(extVar.name, *extVar.varIndex); + m_callbacks->onVariableIndex(extVar.name, *extVar.varIndex, extVar.sourceLocation); } } @@ -76,7 +76,7 @@ namespace Nz::ShaderAst for (const auto& parameter : node.parameters) { if (parameter.varIndex) - m_callbacks->onVariableIndex(parameter.name, *parameter.varIndex); + m_callbacks->onVariableIndex(parameter.name, *parameter.varIndex, parameter.sourceLocation); } } @@ -90,7 +90,7 @@ namespace Nz::ShaderAst m_callbacks->onOptionDeclaration(node); if (m_callbacks->onOptionIndex && node.optIndex) - m_callbacks->onOptionIndex(node.optName, *node.optIndex); + m_callbacks->onOptionIndex(node.optName, *node.optIndex, node.sourceLocation); AstRecursiveVisitor::Visit(node); } @@ -102,7 +102,7 @@ namespace Nz::ShaderAst m_callbacks->onStructDeclaration(node); if (m_callbacks->onStructIndex && node.structIndex) - m_callbacks->onStructIndex(node.description.name, *node.structIndex); + m_callbacks->onStructIndex(node.description.name, *node.structIndex, node.sourceLocation); AstRecursiveVisitor::Visit(node); } @@ -114,7 +114,7 @@ namespace Nz::ShaderAst m_callbacks->onVariableDeclaration(node); if (m_callbacks->onVariableIndex && node.varIndex) - m_callbacks->onVariableIndex(node.varName, *node.varIndex); + m_callbacks->onVariableIndex(node.varName, *node.varIndex, node.sourceLocation); AstRecursiveVisitor::Visit(node); } @@ -123,7 +123,7 @@ namespace Nz::ShaderAst { assert(m_callbacks); if (m_callbacks->onVariableIndex && node.varIndex) - m_callbacks->onVariableIndex(node.varName, *node.varIndex); + m_callbacks->onVariableIndex(node.varName, *node.varIndex, node.sourceLocation); AstRecursiveVisitor::Visit(node); } @@ -132,7 +132,7 @@ namespace Nz::ShaderAst { assert(m_callbacks); if (m_callbacks->onVariableIndex && node.varIndex) - m_callbacks->onVariableIndex(node.varName, *node.varIndex); + m_callbacks->onVariableIndex(node.varName, *node.varIndex, node.sourceLocation); AstRecursiveVisitor::Visit(node); } diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 764cbe3fe..e0fa438e8 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -46,12 +46,12 @@ namespace Nz::ShaderAst Bitset preregisteredIndices; std::unordered_map values; - void PreregisterIndex(std::size_t index) + void PreregisterIndex(std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { if (index < availableIndices.GetSize()) { if (!availableIndices.Test(index)) - throw AstError{ "cannot preregister used index " + std::to_string(index) + " as its already used" }; + throw ShaderLang::AstAlreadyUsedIndexPreregisterError{ sourceLocation, index }; } else if (index >= availableIndices.GetSize()) availableIndices.Resize(index + 1, true); @@ -61,7 +61,7 @@ namespace Nz::ShaderAst } template - std::size_t Register(U&& data, std::optional index = {}) + std::size_t Register(U&& data, std::optional index, const ShaderLang::SourceLocation& sourceLocation) { std::size_t dataIndex; if (index.has_value()) @@ -75,11 +75,11 @@ namespace Nz::ShaderAst if (preregisteredIndices.UnboundedTest(dataIndex)) preregisteredIndices.Reset(dataIndex); else - throw AstError{ "index " + std::to_string(dataIndex) + " is already used" }; + throw ShaderLang::AstInvalidIndexError{ sourceLocation, dataIndex }; } } else - dataIndex = RegisterNewIndex(); + dataIndex = RegisterNewIndex(false); assert(values.find(dataIndex) == values.end()); @@ -88,7 +88,7 @@ namespace Nz::ShaderAst return dataIndex; } - std::size_t RegisterNewIndex(bool preregister = false) + std::size_t RegisterNewIndex(bool preregister) { std::size_t index = availableIndices.FindFirst(); if (index == availableIndices.npos) @@ -105,22 +105,22 @@ namespace Nz::ShaderAst return index; } - T& Retrieve(std::size_t index) + T& Retrieve(std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { auto it = values.find(index); if (it == values.end()) - throw AstError{ "invalid index " + std::to_string(index) }; + throw ShaderLang::AstInvalidIndexError{ sourceLocation, index }; return it->second; } - T* TryRetrieve(std::size_t index) + T* TryRetrieve(std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { auto it = values.find(index); if (it == values.end()) { if (!preregisteredIndices.UnboundedTest(index)) - throw AstError{ "invalid index " + std::to_string(index) }; + throw ShaderLang::AstInvalidIndexError{ sourceLocation, index }; return nullptr; } @@ -248,41 +248,12 @@ namespace Nz::ShaderAst return clone; } - UInt32 SanitizeVisitor::ToSwizzleIndex(char c) - { - switch (c) - { - case 'r': - case 'x': - case 's': - return 0u; - - case 'g': - case 'y': - case 't': - return 1u; - - case 'b': - case 'z': - case 'p': - return 2u; - - case 'a': - case 'w': - case 'q': - return 3u; - - default: - throw AstError{ "unexpected character '" + std::string(&c, 1) + "' on swizzle " }; - } - } - ExpressionValue SanitizeVisitor::CloneType(const ExpressionValue& exprType) { if (!exprType.HasValue()) return {}; - std::optional resolvedType = ResolveTypeExpr(exprType); + std::optional resolvedType = ResolveTypeExpr(exprType, false, {}); if (!resolvedType.has_value()) return AstCloner::CloneType(exprType); @@ -292,9 +263,9 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(AccessIdentifierExpression& node) { if (node.identifiers.empty()) - throw AstError{ "AccessIdentifierExpression must have at least one identifier" }; + throw ShaderLang::AstNoIdentifierError{ node.sourceLocation }; - MandatoryExpr(node.expr); + MandatoryExpr(node.expr, node.sourceLocation); // Handle module access (TODO: Add namespace expression?) if (node.expr->GetType() == NodeType::IdentifierExpression && node.identifiers.size() == 1) @@ -303,12 +274,12 @@ namespace Nz::ShaderAst const IdentifierData* identifierData = FindIdentifier(identifierExpr.identifier); if (identifierData && identifierData->category == IdentifierCategory::Module) { - std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index); + std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index, node.sourceLocation); const auto& env = *m_context->modules[moduleIndex].environment; identifierData = FindIdentifier(env, node.identifiers.front().identifier); if (identifierData) - return HandleIdentifier(identifierData); + return HandleIdentifier(identifierData, node.identifiers.front().sourceLocation); } } @@ -316,7 +287,7 @@ namespace Nz::ShaderAst for (const auto& identifierEntry : node.identifiers) { if (identifierEntry.identifier.empty()) - throw AstError{ "empty identifier" }; + throw ShaderLang::AstEmptyIdentifierError{ identifierEntry.sourceLocation }; const ExpressionType* exprType = GetExpressionType(*indexedExpr); if (!exprType) @@ -342,12 +313,12 @@ namespace Nz::ShaderAst indexedExpr = std::move(identifierExpr); } else - throw AstError{ "type has no method " + identifierEntry.identifier }; + throw ShaderLang::CompilerUnknownMethodError{ identifierEntry.sourceLocation }; } else if (IsStructType(resolvedType)) { std::size_t structIndex = ResolveStruct(resolvedType); - const StructDescription* s = m_context->structs.Retrieve(structIndex); + const StructDescription* s = m_context->structs.Retrieve(structIndex, indexedExpr->sourceLocation); // Retrieve member index (not counting disabled fields) Int32 fieldIndex = 0; @@ -361,7 +332,7 @@ namespace Nz::ShaderAst if (m_context->options.allowPartialSanitization) return AstCloner::Clone(node); //< unresolved - throw AstError{ "cond attribute is not constant" }; + throw ShaderLang::CompilerConstantExpressionRequiredError{ field.cond.GetExpression()->sourceLocation }; } else if (!field.cond.GetResultingValue()) continue; @@ -377,7 +348,7 @@ namespace Nz::ShaderAst } if (!fieldPtr) - throw AstError{ "unknown field " + identifierEntry.identifier }; + throw ShaderLang::CompilerUnknownFieldError{ indexedExpr->sourceLocation, identifierEntry.identifier }; if (m_context->options.useIdentifierAccessesForStructs) { @@ -386,25 +357,32 @@ namespace Nz::ShaderAst if (indexedExpr->GetType() != NodeType::AccessIdentifierExpression) { std::unique_ptr accessIndex = std::make_unique(); + accessIndex->sourceLocation = indexedExpr->sourceLocation; accessIndex->expr = std::move(indexedExpr); accessIdentifierPtr = accessIndex.get(); indexedExpr = std::move(accessIndex); } else + { accessIdentifierPtr = static_cast(indexedExpr.get()); + accessIdentifierPtr->sourceLocation.ExtendToRight(indexedExpr->sourceLocation); + } - accessIdentifierPtr->cachedExpressionType = ResolveTypeExpr(fieldPtr->type); + accessIdentifierPtr->cachedExpressionType = ResolveTypeExpr(fieldPtr->type, false, identifierEntry.sourceLocation); - accessIdentifierPtr->identifiers.emplace_back().identifier = fieldPtr->name; + auto& newIdentifierEntry = accessIdentifierPtr->identifiers.emplace_back(); + newIdentifierEntry.identifier = fieldPtr->name; + newIdentifierEntry.sourceLocation = indexedExpr->sourceLocation; } else { // Transform to AccessIndexExpression std::unique_ptr accessIndex = std::make_unique(); + accessIndex->sourceLocation = indexedExpr->sourceLocation; accessIndex->expr = std::move(indexedExpr); accessIndex->indices.push_back(ShaderBuilder::Constant(fieldIndex)); - accessIndex->cachedExpressionType = ResolveTypeExpr(fieldPtr->type); + accessIndex->cachedExpressionType = ResolveTypeExpr(fieldPtr->type, false, identifierEntry.sourceLocation); indexedExpr = std::move(accessIndex); } @@ -414,15 +392,14 @@ namespace Nz::ShaderAst // Swizzle expression std::size_t swizzleComponentCount = identifierEntry.identifier.size(); if (swizzleComponentCount > 4) - throw AstError{ "cannot swizzle more than four elements" }; + throw ShaderLang::CompilerInvalidSwizzleError{ identifierEntry.sourceLocation }; if (m_context->options.removeScalarSwizzling && IsPrimitiveType(resolvedType)) { for (std::size_t j = 0; j < swizzleComponentCount; ++j) { - if (ToSwizzleIndex(identifierEntry.identifier[j]) != 0) - throw AstError{ "invalid swizzle" }; - //throw ShaderLang::CompilerInvalidSwizzleError{}; + if (ToSwizzleIndex(identifierEntry.identifier[j], identifierEntry.sourceLocation) != 0) + throw ShaderLang::CompilerInvalidScalarSwizzleError{ identifierEntry.sourceLocation }; } if (swizzleComponentCount == 1) @@ -453,7 +430,7 @@ namespace Nz::ShaderAst swizzle->componentCount = swizzleComponentCount; for (std::size_t j = 0; j < swizzleComponentCount; ++j) - swizzle->components[j] = ToSwizzleIndex(identifierEntry.identifier[j]); + swizzle->components[j] = ToSwizzleIndex(identifierEntry.identifier[j], identifierEntry.sourceLocation); Validate(*swizzle); @@ -461,7 +438,7 @@ namespace Nz::ShaderAst } } else - throw AstError{ "unexpected type (only struct and vectors can be indexed with identifiers)" }; //< TODO: Add support for arrays + throw ShaderLang::CompilerUnexpectedAccessedTypeError{ node.sourceLocation }; } return indexedExpr; @@ -469,9 +446,9 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(AccessIndexExpression& node) { - MandatoryExpr(node.expr); + MandatoryExpr(node.expr, node.sourceLocation); for (auto& index : node.indices) - MandatoryExpr(index); + MandatoryExpr(index, node.sourceLocation); auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); Validate(*clone); @@ -483,8 +460,8 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(AliasValueExpression& node) { - const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId)); - ExpressionPtr targetExpr = HandleIdentifier(targetIdentifier); + const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId, node.sourceLocation), node.sourceLocation); + ExpressionPtr targetExpr = HandleIdentifier(targetIdentifier, node.sourceLocation); if (m_context->options.removeAliases) return targetExpr; @@ -502,8 +479,8 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node) { - MandatoryExpr(node.left); - MandatoryExpr(node.right); + MandatoryExpr(node.left, node.sourceLocation); + MandatoryExpr(node.right, node.sourceLocation); auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); Validate(*clone); @@ -521,7 +498,7 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node) { - ExpressionPtr targetExpr = CloneExpression(MandatoryExpr(node.targetFunction)); + ExpressionPtr targetExpr = CloneExpression(MandatoryExpr(node.targetFunction, node.sourceLocation)); const ExpressionType* targetExprType = GetExpressionType(*targetExpr); if (!targetExprType) return AstCloner::Clone(node); //< unresolved type @@ -531,7 +508,7 @@ namespace Nz::ShaderAst if (IsFunctionType(resolvedType)) { if (!m_context->currentFunction) - throw AstError{ "function calls must happen inside a function" }; + throw ShaderLang::CompilerFunctionCallOutsideOfFunctionError{ node.sourceLocation }; std::size_t targetFuncIndex; if (targetExpr->GetType() == NodeType::FunctionExpression) @@ -540,16 +517,17 @@ namespace Nz::ShaderAst { const auto& alias = static_cast(*targetExpr); - const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId)); + const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId, node.sourceLocation), targetExpr->sourceLocation); if (targetIdentifier->category != IdentifierCategory::Function) - throw AstError{ "expected function expression" }; + throw ShaderLang::CompilerExpectedFunctionError{ targetExpr->sourceLocation }; targetFuncIndex = targetIdentifier->index; } else - throw AstError{ "expected function expression" }; + throw ShaderLang::CompilerExpectedFunctionError{ targetExpr->sourceLocation }; auto clone = std::make_unique(); + clone->sourceLocation = node.sourceLocation; clone->targetFunction = std::move(targetExpr); clone->parameters.reserve(node.parameters.size()); @@ -565,7 +543,7 @@ namespace Nz::ShaderAst else if (IsIntrinsicFunctionType(resolvedType)) { if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression) - throw AstError{ "expected intrinsic function expression" }; + throw ShaderLang::CompilerExpectedIntrinsicFunctionError{ targetExpr->sourceLocation }; std::size_t targetIntrinsicId = static_cast(*targetExpr).intrinsicId; @@ -575,7 +553,8 @@ namespace Nz::ShaderAst for (const auto& param : node.parameters) parameters.push_back(CloneExpression(param)); - auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics.Retrieve(targetIntrinsicId), std::move(parameters)); + auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics.Retrieve(targetIntrinsicId, node.sourceLocation), std::move(parameters)); + intrinsic->sourceLocation = node.sourceLocation; Validate(*intrinsic); return intrinsic; @@ -596,6 +575,7 @@ namespace Nz::ShaderAst assert(IsSamplerType(methodType.objectType->type) && methodType.methodIndex == 0); auto intrinsic = ShaderBuilder::Intrinsic(IntrinsicType::SampleTexture, std::move(parameters)); + intrinsic->sourceLocation = node.sourceLocation; Validate(*intrinsic); return intrinsic; @@ -604,10 +584,11 @@ namespace Nz::ShaderAst { // Calling a type - vec3[f32](0.0, 1.0, 2.0) - it's a cast auto clone = std::make_unique(); + clone->sourceLocation = node.sourceLocation; clone->targetType = *targetExprType; if (node.parameters.size() > clone->expressions.size()) - throw AstError{ "component count doesn't match required component count" }; + throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation }; for (std::size_t i = 0; i < node.parameters.size(); ++i) clone->expressions[i] = CloneExpression(node.parameters[i]); @@ -639,6 +620,7 @@ namespace Nz::ShaderAst } auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", targetType); //< Validation will prevent name-clash if required + variableDeclaration->sourceLocation = node.sourceLocation; Validate(*variableDeclaration); std::size_t variableIndex = *variableDeclaration->varIndex; @@ -649,6 +631,7 @@ namespace Nz::ShaderAst { // temp[i] auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, targetType), ShaderBuilder::Constant(UInt32(i))); + columnExpr->sourceLocation = node.sourceLocation; Validate(*columnExpr); // vector expression @@ -658,6 +641,7 @@ namespace Nz::ShaderAst { // fromMatrix[i] auto matrixColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(clone->expressions.front()), ShaderBuilder::Constant(UInt32(i))); + matrixColumnExpr->sourceLocation = node.sourceLocation; Validate(*matrixColumnExpr); vectorExpr = std::move(matrixColumnExpr); @@ -683,6 +667,7 @@ namespace Nz::ShaderAst expressions[j + 1] = ShaderBuilder::Constant(ExpressionType{ targetMatrixType.type }, (i == j + vectorComponentCount) ? 1 : 0); //< set 1 to diagonal vecCast = ShaderBuilder::Cast(ExpressionType{ VectorType{ targetMatrixType.rowCount, targetMatrixType.type } }, std::move(expressions)); + vecCast->sourceLocation = node.sourceLocation; Validate(*vecCast); castExpr = std::move(vecCast); @@ -693,6 +678,7 @@ namespace Nz::ShaderAst std::iota(swizzleComponents.begin(), swizzleComponents.begin() + targetMatrixType.rowCount, 0); auto swizzleExpr = ShaderBuilder::Swizzle(std::move(vectorExpr), swizzleComponents, targetMatrixType.rowCount); + swizzleExpr->sourceLocation = node.sourceLocation; Validate(*swizzleExpr); castExpr = std::move(swizzleExpr); @@ -702,10 +688,16 @@ namespace Nz::ShaderAst castExpr = std::move(vectorExpr); // temp[i] = castExpr - m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr)))); + auto assignExpr = ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr)); + assignExpr->sourceLocation = node.sourceLocation; + + m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(std::move(assignExpr))); } - return ShaderBuilder::Variable(variableIndex, targetType); + auto varExpr = ShaderBuilder::Variable(variableIndex, targetType); + varExpr->sourceLocation = node.sourceLocation; + + return varExpr; } return clone; @@ -713,9 +705,9 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConditionalExpression& node) { - MandatoryExpr(node.condition); - MandatoryExpr(node.truePath); - MandatoryExpr(node.falsePath); + MandatoryExpr(node.condition, node.sourceLocation); + MandatoryExpr(node.truePath, node.sourceLocation); + MandatoryExpr(node.falsePath, node.sourceLocation); ExpressionPtr cloneCondition = AstCloner::Clone(*node.condition); @@ -727,7 +719,7 @@ namespace Nz::ShaderAst } if (GetConstantType(*conditionValue) != ExpressionType{ PrimitiveType::Boolean }) - throw AstError{ "expected a boolean value" }; + throw ShaderLang::CompilerConditionExpectedBoolError{ cloneCondition->sourceLocation }; if (std::get(*conditionValue)) return AstCloner::Clone(*node.truePath); @@ -738,7 +730,7 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConstantValueExpression& node) { if (std::holds_alternative(node.value)) - throw std::runtime_error("expected a value"); + throw ShaderLang::CompilerConstantExpectedValueError{ node.sourceLocation }; auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); clone->cachedExpressionType = GetConstantType(clone->value); @@ -748,11 +740,11 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) { - const ConstantValue* value = m_context->constantValues.TryRetrieve(node.constantId); + const ConstantValue* value = m_context->constantValues.TryRetrieve(node.constantId, node.sourceLocation); if (!value) { if (!m_context->options.allowPartialSanitization) - throw std::runtime_error("invalid constant index #" + std::to_string(node.constantId)); + throw ShaderLang::AstInvalidConstantIndexError{ node.sourceLocation, node.constantId }; return AstCloner::Clone(node); //< unresolved } @@ -760,6 +752,7 @@ namespace Nz::ShaderAst // Replace by constant value auto constant = ShaderBuilder::Constant(*value); constant->cachedExpressionType = GetConstantType(constant->value); + constant->sourceLocation = node.sourceLocation; return constant; } @@ -774,13 +767,13 @@ namespace Nz::ShaderAst if (m_context->allowUnknownIdentifiers) return AstCloner::Clone(node); - throw AstError{ "unknown identifier " + node.identifier }; + throw ShaderLang::CompilerUnknownIdentifierError{ node.sourceLocation, node.identifier }; } if (identifierData->category == IdentifierCategory::Unresolved) return AstCloner::Clone(node); - return HandleIdentifier(identifierData); + return HandleIdentifier(identifierData, node.sourceLocation); } ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node) @@ -793,11 +786,17 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node) { - auto expression = CloneExpression(MandatoryExpr(node.expression)); + auto expression = CloneExpression(MandatoryExpr(node.expression, node.sourceLocation)); const ExpressionType* exprType = GetExpressionType(*expression); if (!exprType) - return ShaderBuilder::Swizzle(std::move(expression), node.components, node.componentCount); //< unresolved + { + auto swizzleExpr = ShaderBuilder::Swizzle(std::move(expression), node.components, node.componentCount); //< unresolved + swizzleExpr->cachedExpressionType = node.cachedExpressionType; + swizzleExpr->sourceLocation = node.sourceLocation; + + return swizzleExpr; + } const ExpressionType& resolvedExprType = ResolveAlias(*exprType); @@ -806,9 +805,9 @@ namespace Nz::ShaderAst for (std::size_t i = 0; i < node.componentCount; ++i) { if (node.components[i] != 0) - throw AstError{ "invalid swizzle" }; - + throw ShaderLang::CompilerInvalidScalarSwizzleError{ node.sourceLocation }; } + if (node.componentCount == 1) return expression; //< ignore this swizzle (a.x == a) @@ -822,6 +821,7 @@ namespace Nz::ShaderAst baseType = std::get(resolvedExprType); auto cast = std::make_unique(); + cast->sourceLocation = node.sourceLocation; cast->targetType = ExpressionType{ VectorType{ node.componentCount, baseType } }; for (std::size_t j = 0; j < node.componentCount; ++j) cast->expressions[j] = CloneExpression(expression); @@ -836,6 +836,7 @@ namespace Nz::ShaderAst clone->componentCount = node.componentCount; clone->components = node.components; clone->expression = std::move(expression); + clone->sourceLocation = node.sourceLocation; Validate(*clone); return clone; @@ -865,14 +866,14 @@ namespace Nz::ShaderAst // Evaluate every condition at compilation and select the right statement for (auto& cond : node.condStatements) { - MandatoryExpr(cond.condition); + MandatoryExpr(cond.condition, node.sourceLocation); std::optional conditionValue = ComputeConstantValue(*AstCloner::Clone(*cond.condition)); if (!conditionValue.has_value()) return AstCloner::Clone(node); //< Unresolvable condition if (GetConstantType(*conditionValue) != ExpressionType{ PrimitiveType::Boolean }) - throw AstError{ "expected a boolean value" }; + throw ShaderLang::CompilerConditionExpectedBoolError{ cond.condition->sourceLocation }; if (std::get(*conditionValue)) return Unscope(AstCloner::Clone(*cond.statement)); @@ -889,7 +890,7 @@ namespace Nz::ShaderAst clone->condStatements.reserve(node.condStatements.size()); if (!m_context->currentFunction) - throw AstError{ "non-const branching statements can only exist inside a function" }; + throw ShaderLang::CompilerBranchOutsideOfFunctionError{ node.sourceLocation }; BranchStatement* root = clone.get(); for (std::size_t condIndex = 0; condIndex < node.condStatements.size(); ++condIndex) @@ -900,16 +901,16 @@ namespace Nz::ShaderAst auto BuildCondStatement = [&](BranchStatement::ConditionalStatement& condStatement) { - condStatement.condition = CloneExpression(MandatoryExpr(cond.condition)); + condStatement.condition = CloneExpression(MandatoryExpr(cond.condition, node.sourceLocation)); const ExpressionType* condType = GetExpressionType(*condStatement.condition); if (!condType) return ValidationResult::Unresolved; if (!IsPrimitiveType(*condType) || std::get(*condType) != PrimitiveType::Boolean) - throw AstError{ "branch expressions must resolve to boolean type" }; + throw ShaderLang::CompilerConditionExpectedBoolError{ condStatement.condition->sourceLocation }; - condStatement.statement = CloneStatement(MandatoryStatement(cond.statement)); + condStatement.statement = CloneStatement(MandatoryStatement(cond.statement, node.sourceLocation)); return ValidationResult::Validated; }; @@ -944,8 +945,8 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(ConditionalStatement& node) { - MandatoryExpr(node.condition); - MandatoryStatement(node.statement); + MandatoryExpr(node.condition, node.sourceLocation); + MandatoryStatement(node.statement, node.sourceLocation); ExpressionPtr cloneCondition = AstCloner::Clone(*node.condition); @@ -953,11 +954,14 @@ namespace Nz::ShaderAst if (!conditionValue.has_value()) { // Unresolvable condition - return ShaderBuilder::ConditionalStatement(std::move(cloneCondition), AstCloner::Clone(*node.statement)); + auto condStatement = ShaderBuilder::ConditionalStatement(std::move(cloneCondition), AstCloner::Clone(*node.statement)); + condStatement->sourceLocation = node.sourceLocation; + + return condStatement; } if (GetConstantType(*conditionValue) != ExpressionType{ PrimitiveType::Boolean }) - throw AstError{ "expected a boolean value" }; + throw ShaderLang::CompilerConditionExpectedBoolError{ cloneCondition->sourceLocation }; if (std::get(*conditionValue)) return AstCloner::Clone(*node.statement); @@ -981,15 +985,15 @@ namespace Nz::ShaderAst auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); if (!clone->expression) - throw AstError{ "const variables must have an expression" }; + throw ShaderLang::CompilerConstMissingExpressionError{ node.sourceLocation }; clone->expression = PropagateConstants(*clone->expression); if (clone->expression->GetType() != NodeType::ConstantValueExpression) { if (!m_context->options.allowPartialSanitization) - throw AstError{ "const variable must have constant expressions " }; + throw ShaderLang::CompilerConstantExpressionRequiredError{ clone->expression->sourceLocation }; - clone->constIndex = RegisterConstant(clone->name, std::nullopt, clone->constIndex); + clone->constIndex = RegisterConstant(clone->name, std::nullopt, clone->constIndex, node.sourceLocation); return clone; } @@ -997,14 +1001,14 @@ namespace Nz::ShaderAst ExpressionType expressionType = GetConstantType(value); - std::optional constType = ResolveTypeExpr(clone->type, true); + std::optional constType = ResolveTypeExpr(clone->type, true, node.sourceLocation); if (clone->type.HasValue() && constType.has_value() && *constType != ResolveAlias(expressionType)) - throw AstError{ "constant expression doesn't match type" }; + throw ShaderLang::CompilerVarDeclarationTypeUnmatchingError{ clone->expression->sourceLocation }; clone->type = expressionType; - clone->constIndex = RegisterConstant(clone->name, value, clone->constIndex); + clone->constIndex = RegisterConstant(clone->name, value, clone->constIndex, node.sourceLocation); if (m_context->options.removeConstDeclaration) return ShaderBuilder::NoOp(); @@ -1030,7 +1034,7 @@ namespace Nz::ShaderAst for (auto& extVar : clone->externalVars) { if (!extVar.bindingIndex.HasValue()) - throw AstError{ "external variable " + extVar.name + " requires a binding index" }; + throw ShaderLang::CompilerExtMissingBindingIndexError{ extVar.sourceLocation }; if (extVar.bindingSet.HasValue()) ComputeExprValue(extVar.bindingSet); @@ -1046,17 +1050,17 @@ namespace Nz::ShaderAst UInt64 bindingKey = bindingSet << 32 | bindingIndex; if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end()) - throw AstError{ "binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" }; + throw ShaderLang::CompilerExtBindingAlreadyUsedError{ extVar.sourceLocation, UInt32(bindingSet), UInt32(bindingIndex) }; m_context->usedBindingIndexes.insert(bindingKey); } if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end()) - throw AstError{ "external variable " + extVar.name + " is already declared" }; + throw ShaderLang::CompilerExtAlreadyDeclaredError{ extVar.sourceLocation, extVar.name }; m_context->declaredExternalVar.insert(extVar.name); - std::optional resolvedType = ResolveTypeExpr(extVar.type); + std::optional resolvedType = ResolveTypeExpr(extVar.type, false, node.sourceLocation); if (!resolvedType.has_value()) { RegisterUnresolved(extVar.name); @@ -1071,10 +1075,10 @@ namespace Nz::ShaderAst else if (IsSamplerType(targetType)) varType = targetType; else - throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; + throw ShaderLang::CompilerExtTypeNotAllowedError{ extVar.sourceLocation, extVar.name }; extVar.type = std::move(resolvedType).value(); - extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), extVar.varIndex); + extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), extVar.varIndex, extVar.sourceLocation); SanitizeIdentifier(extVar.name); } @@ -1085,7 +1089,7 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node) { if (m_context->currentFunction) - throw AstError{ "a function cannot be defined inside another function" }; + throw ShaderLang::CompilerFunctionDeclarationInsideFunctionError{ node.sourceLocation }; auto clone = std::make_unique(); clone->name = node.name; @@ -1097,6 +1101,7 @@ namespace Nz::ShaderAst cloneParam.name = parameter.name; cloneParam.type = CloneType(parameter.type); cloneParam.varIndex = parameter.varIndex; + cloneParam.sourceLocation = parameter.sourceLocation; } if (node.returnType.HasValue()) @@ -1123,21 +1128,31 @@ namespace Nz::ShaderAst if (!m_context->options.allowPartialSanitization) { if (m_context->entryFunctions[UnderlyingCast(stageType)]) - throw AstError{ "the same entry type has been defined multiple times" }; + throw ShaderLang::CompilerEntryPointAlreadyDefinedError{ clone->sourceLocation }; m_context->entryFunctions[UnderlyingCast(stageType)] = &node; } if (node.parameters.size() > 1) - throw AstError{ "entry functions can either take one struct parameter or no parameter" }; + throw ShaderLang::CompilerEntryFunctionParameterError{ node.parameters[1].sourceLocation }; + + if (!node.parameters.empty()) + { + auto& parameter = node.parameters.front(); + if (parameter.type.IsResultingValue()) + { + if (!IsStructType(ResolveAlias(parameter.type.GetResultingValue()))) + throw ShaderLang::CompilerEntryFunctionParameterError{ parameter.sourceLocation }; + } + } if (stageType != ShaderStageType::Fragment) { if (node.depthWrite.HasValue()) - throw AstError{ "only fragment entry-points can have the depth_write attribute" }; + throw ShaderLang::CompilerDepthWriteAttributeError{ node.sourceLocation }; if (node.earlyFragmentTests.HasValue()) - throw AstError{ "only functions with entry(frag) attribute can have the early_fragments_tests attribute" }; + throw ShaderLang::CompilerEarlyFragmentTestsAttributeError{ node.sourceLocation }; } } @@ -1149,13 +1164,13 @@ namespace Nz::ShaderAst if (clone->earlyFragmentTests.HasValue() && clone->earlyFragmentTests.GetResultingValue()) { //TODO: warning and disable early fragment tests - throw AstError{ "discard is not compatible with early fragment tests" }; + throw ShaderLang::CompilerDiscardEarlyFragmentTestsError{ node.sourceLocation }; } FunctionData funcData; funcData.node = clone.get(); //< update function node - std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData), node.funcIndex); + std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData), node.funcIndex, node.sourceLocation); clone->funcIndex = funcIndex; SanitizeIdentifier(clone->name); @@ -1170,16 +1185,16 @@ namespace Nz::ShaderAst auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); if (clone->optName.empty()) - throw AstError{ "empty option name" }; + throw ShaderLang::AstEmptyIdentifierError{ node.sourceLocation }; - std::optional resolvedOptionType = ResolveTypeExpr(clone->optType); + std::optional resolvedOptionType = ResolveTypeExpr(clone->optType, false, node.sourceLocation); if (!resolvedOptionType) { - clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex); + clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex, node.sourceLocation); return clone; } - ExpressionType resolvedType = ResolveType(*resolvedOptionType); + ExpressionType resolvedType = ResolveType(*resolvedOptionType, false, node.sourceLocation); const ExpressionType& targetType = ResolveAlias(resolvedType); if (clone->defaultValue) @@ -1187,12 +1202,12 @@ namespace Nz::ShaderAst const ExpressionType* defaultValueType = GetExpressionType(*clone->defaultValue); if (!defaultValueType) { - clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex); + clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex, node.sourceLocation); return clone; //< unresolved } if (targetType != *defaultValueType) - throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" }; + throw ShaderLang::CompilerVarDeclarationTypeUnmatchingError{ node.sourceLocation }; } clone->optType = std::move(resolvedType); @@ -1200,21 +1215,20 @@ namespace Nz::ShaderAst UInt32 optionHash = CRC32(reinterpret_cast(clone->optName.data()), clone->optName.size()); if (auto optionValueIt = m_context->options.optionValues.find(optionHash); optionValueIt != m_context->options.optionValues.end()) - clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second, node.optIndex); + clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second, node.optIndex, node.sourceLocation); else { if (m_context->options.allowPartialSanitization) { // Partial sanitization, we cannot give a value to this option - clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex); + clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex, node.sourceLocation); } else { - if (!clone->defaultValue) - throw AstError{ "missing option " + clone->optName + " value (has no default value)" }; + throw ShaderLang::CompilerMissingOptionValueError{ node.sourceLocation, clone->optName }; - clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue), node.optIndex); + clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue), node.optIndex, node.sourceLocation); } } @@ -1227,7 +1241,7 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node) { if (m_context->currentFunction) - throw AstError{ "structs must be declared outside of functions" }; + throw ShaderLang::CompilerStructDeclarationInsideFunctionError{ node.sourceLocation }; auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); @@ -1280,14 +1294,14 @@ namespace Nz::ShaderAst else if (IsStructType(targetType)) { std::size_t structIndex = std::get(targetType).structIndex; - const StructDescription* desc = m_context->structs.Retrieve(structIndex); + const StructDescription* desc = m_context->structs.Retrieve(structIndex, member.sourceLocation); if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue()) throw AstError{ "inner struct layout mismatch" }; } } } - clone->structIndex = RegisterStruct(clone->description.name, &clone->description, clone->structIndex); + clone->structIndex = RegisterStruct(clone->description.name, &clone->description, clone->structIndex, clone->sourceLocation); SanitizeIdentifier(clone->description.name); @@ -1317,7 +1331,7 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(ExpressionStatement& node) { - MandatoryExpr(node.expression); + MandatoryExpr(node.expression, node.sourceLocation); return AstCloner::Clone(node); } @@ -1327,10 +1341,10 @@ namespace Nz::ShaderAst if (node.varName.empty()) throw AstError{ "numerical for variable name cannot be empty" }; - auto fromExpr = CloneExpression(MandatoryExpr(node.fromExpr)); + auto fromExpr = CloneExpression(MandatoryExpr(node.fromExpr, node.sourceLocation)); auto stepExpr = CloneExpression(node.stepExpr); - auto toExpr = CloneExpression(MandatoryExpr(node.toExpr)); - MandatoryStatement(node.statement); + auto toExpr = CloneExpression(MandatoryExpr(node.toExpr, node.sourceLocation)); + MandatoryStatement(node.statement, node.sourceLocation); const ExpressionType* fromExprType = GetExpressionType(*fromExpr); const ExpressionType* toExprType = GetExpressionType(*fromExpr); @@ -1349,7 +1363,7 @@ namespace Nz::ShaderAst PushScope(); { if (fromExprType) - clone->varIndex = RegisterVariable(node.varName, *fromExprType, node.varIndex); + clone->varIndex = RegisterVariable(node.varName, *fromExprType, node.varIndex, node.sourceLocation); else { RegisterUnresolved(node.varName); @@ -1425,7 +1439,12 @@ namespace Nz::ShaderAst for (; counter < to; counter += step) { - auto var = ShaderBuilder::DeclareVariable(node.varName, ShaderBuilder::Constant(counter)); + auto constant = ShaderBuilder::Constant(counter); + constant->sourceLocation = node.sourceLocation; + + auto var = ShaderBuilder::DeclareVariable(node.varName, std::move(constant)); + var->sourceLocation = node.sourceLocation; + Validate(*var); multi->statements.emplace_back(std::move(var)); @@ -1461,6 +1480,7 @@ namespace Nz::ShaderAst // Counter variable auto counterVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(fromExpr)); + counterVariable->sourceLocation = node.sourceLocation; counterVariable->varIndex = node.varIndex; Validate(*counterVariable); @@ -1469,6 +1489,7 @@ namespace Nz::ShaderAst // Target variable auto targetVariable = ShaderBuilder::DeclareVariable("to", std::move(toExpr)); + targetVariable->sourceLocation = node.sourceLocation; Validate(*targetVariable); std::size_t targetVarIndex = targetVariable->varIndex.value(); @@ -1480,6 +1501,7 @@ namespace Nz::ShaderAst if (stepExpr) { auto stepVariable = ShaderBuilder::DeclareVariable("step", std::move(stepExpr)); + stepVariable->sourceLocation = node.sourceLocation; Validate(*stepVariable); stepVarIndex = stepVariable->varIndex; @@ -1491,7 +1513,14 @@ namespace Nz::ShaderAst whileStatement->unroll = std::move(unrollValue); // While condition - auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, counterType), ShaderBuilder::Variable(targetVarIndex, counterType)); + auto conditionCounterVariable = ShaderBuilder::Variable(counterVarIndex, counterType); + conditionCounterVariable->sourceLocation = node.sourceLocation; + + auto conditionTargetVariable = ShaderBuilder::Variable(targetVarIndex, counterType); + conditionTargetVariable->sourceLocation = node.sourceLocation; + + auto condition = ShaderBuilder::Binary(BinaryType::CompLt, std::move(conditionCounterVariable), std::move(conditionTargetVariable)); + condition->sourceLocation = node.sourceLocation; Validate(*condition); whileStatement->condition = std::move(condition); @@ -1509,6 +1538,7 @@ namespace Nz::ShaderAst incrExpr = (counterType == PrimitiveType::Int32) ? ShaderBuilder::Constant(1) : ShaderBuilder::Constant(1u); auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, counterType), std::move(incrExpr)); + incrCounter->sourceLocation = node.sourceLocation; Validate(*incrCounter); body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); @@ -1527,10 +1557,10 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(ForEachStatement& node) { - auto expr = CloneExpression(MandatoryExpr(node.expression)); + auto expr = CloneExpression(MandatoryExpr(node.expression, node.sourceLocation)); if (node.varName.empty()) - throw AstError{ "for-each variable name cannot be empty"}; + throw ShaderLang::AstEmptyIdentifierError{ node.sourceLocation }; const ExpressionType* exprType = GetExpressionType(*expr); if (!exprType) @@ -1545,7 +1575,7 @@ namespace Nz::ShaderAst innerType = arrayType.containedType->type; } else - throw AstError{ "for-each is only supported on arrays and range expressions" }; + throw ShaderLang::CompilerForEachUnsupportedTypeError{ node.sourceLocation }; ExpressionValue unrollValue; if (node.unroll.HasValue()) @@ -1559,6 +1589,8 @@ namespace Nz::ShaderAst // Repeat code auto multi = std::make_unique(); + multi->sourceLocation = node.sourceLocation; + if (IsArrayType(resolvedExprType)) { const ArrayType& arrayType = std::get(resolvedExprType); @@ -1644,10 +1676,11 @@ namespace Nz::ShaderAst clone->expression = std::move(expr); clone->varName = node.varName; clone->unroll = std::move(unrollValue); + clone->sourceLocation = node.sourceLocation; PushScope(); { - clone->varIndex = RegisterVariable(node.varName, innerType, node.varIndex); + clone->varIndex = RegisterVariable(node.varName, innerType, node.varIndex, node.sourceLocation); clone->statement = CloneStatement(node.statement); } PopScope(); @@ -1737,7 +1770,7 @@ namespace Nz::ShaderAst // Module has already been imported moduleIndex = it->second; if (moduleIndex == Context::ModuleIdSentinel) - throw AstError{ "circular import detected" }; + throw ShaderLang::CompilerCircularImportError{ node.sourceLocation, node.moduleName }; } auto& moduleData = m_context->modules[moduleIndex]; @@ -1808,7 +1841,7 @@ namespace Nz::ShaderAst m_context->currentStatementList = &clone->statements; for (auto& statement : node.statements) - clone->statements.push_back(AstCloner::Clone(MandatoryStatement(statement))); + clone->statements.push_back(AstCloner::Clone(MandatoryStatement(statement, node.sourceLocation))); m_context->currentStatementList = previousList; @@ -1817,7 +1850,7 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(ScopedStatement& node) { - MandatoryStatement(node.statement); + MandatoryStatement(node.statement, node.sourceLocation); PushScope(); @@ -1830,8 +1863,8 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(WhileStatement& node) { - MandatoryExpr(node.condition); - MandatoryStatement(node.body); + MandatoryExpr(node.condition, node.sourceLocation); + MandatoryStatement(node.body, node.sourceLocation); auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); if (Validate(*clone) == ValidationResult::Unresolved) @@ -1840,7 +1873,7 @@ namespace Nz::ShaderAst if (clone->unroll.HasValue()) { if (ComputeExprValue(clone->unroll) == ValidationResult::Validated && clone->unroll.GetResultingValue() == LoopUnroll::Always) - throw AstError{ "unroll(always) is not yet supported on while" }; + throw ShaderLang::CompilerWhileUnrollNotSupportedError{ node.sourceLocation }; } return clone; @@ -1895,7 +1928,28 @@ namespace Nz::ShaderAst return &it->data; } - ExpressionPtr SanitizeVisitor::HandleIdentifier(const IdentifierData* identifierData) + const ExpressionType* SanitizeVisitor::GetExpressionType(Expression& expr) const + { + const ExpressionType* expressionType = ShaderAst::GetExpressionType(expr); + if (!expressionType) + { + if (!m_context->options.allowPartialSanitization) + throw ShaderLang::AstInternalError{ expr.sourceLocation, "unexpected missing expression type" }; + } + + return expressionType; + } + + const ExpressionType& SanitizeVisitor::GetExpressionTypeSecure(Expression& expr) const + { + const ExpressionType* expressionType = GetExpressionType(expr); + if (!expressionType) + throw ShaderLang::AstInternalError{ expr.sourceLocation, "unexpected missing expression type" }; + + return *expressionType; + } + + ExpressionPtr SanitizeVisitor::HandleIdentifier(const IdentifierData* identifierData, const ShaderLang::SourceLocation& sourceLocation) { switch (identifierData->category) { @@ -1903,6 +1957,7 @@ namespace Nz::ShaderAst { AliasValueExpression aliasValue; aliasValue.aliasId = identifierData->index; + aliasValue.sourceLocation = sourceLocation; return Clone(aliasValue); } @@ -1912,6 +1967,7 @@ namespace Nz::ShaderAst // Replace IdentifierExpression by Constant(Value)Expression ConstantExpression constantExpr; constantExpr.constantId = identifierData->index; + constantExpr.sourceLocation = sourceLocation; return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression } @@ -1922,30 +1978,33 @@ namespace Nz::ShaderAst auto funcExpr = std::make_unique(); funcExpr->cachedExpressionType = FunctionType{ identifierData->index }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type funcExpr->funcId = identifierData->index; + funcExpr->sourceLocation = sourceLocation; return funcExpr; } case IdentifierCategory::Intrinsic: { - IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index); + IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index, sourceLocation); // Replace IdentifierExpression by IntrinsicFunctionExpression auto intrinsicExpr = std::make_unique(); intrinsicExpr->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type intrinsicExpr->intrinsicId = identifierData->index; + intrinsicExpr->sourceLocation = sourceLocation; return intrinsicExpr; } case IdentifierCategory::Module: - throw AstError{ "unexpected module identifier" }; + throw ShaderLang::AstUnexpectedIdentifierError{ sourceLocation, "module" }; case IdentifierCategory::Struct: { // Replace IdentifierExpression by StructTypeExpression auto structExpr = std::make_unique(); structExpr->cachedExpressionType = StructType{ identifierData->index }; + structExpr->sourceLocation = sourceLocation; structExpr->structTypeId = identifierData->index; return structExpr; @@ -1955,63 +2014,28 @@ namespace Nz::ShaderAst { auto typeExpr = std::make_unique(); typeExpr->cachedExpressionType = Type{ identifierData->index }; + typeExpr->sourceLocation = sourceLocation; typeExpr->typeId = identifierData->index; return typeExpr; } case IdentifierCategory::Unresolved: - throw AstError{ "unexpected unresolved identifier" }; + throw ShaderLang::AstUnexpectedIdentifierError{ sourceLocation, "unresolved" }; case IdentifierCategory::Variable: { // Replace IdentifierExpression by VariableExpression auto varExpr = std::make_unique(); - varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index); + varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index, sourceLocation); + varExpr->sourceLocation = sourceLocation; varExpr->variableId = identifierData->index; return varExpr; } } - throw AstError{ "internal error" }; - } - - const ExpressionType* SanitizeVisitor::GetExpressionType(Expression& expr) const - { - const ExpressionType* expressionType = ShaderAst::GetExpressionType(expr); - if (!expressionType) - { - if (!m_context->options.allowPartialSanitization) - throw AstError{ "unexpected missing expression type" }; //< InternalError - } - - return expressionType; - } - - const ExpressionType& SanitizeVisitor::GetExpressionTypeSecure(Expression& expr) const - { - const ExpressionType* expressionType = GetExpressionType(expr); - if (!expressionType) - throw AstError{ "unexpected missing expression type" }; //< InternalError - - return *expressionType; - } - - Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node) const - { - if (!node) - throw AstError{ "Invalid expression" }; - - return *node; - } - - Statement& SanitizeVisitor::MandatoryStatement(const StatementPtr& node) const - { - if (!node) - throw AstError{ "Invalid statement" }; - - return *node; + throw ShaderLang::AstInternalError{ sourceLocation, "unhandled identifier category" }; } void SanitizeVisitor::PushScope() @@ -2040,6 +2064,7 @@ namespace Nz::ShaderAst Validate(*variableDeclaration); auto varExpr = std::make_unique(); + varExpr->sourceLocation = variableDeclaration->initialExpression->sourceLocation; varExpr->variableId = *variableDeclaration->varIndex; m_context->currentStatementList->push_back(std::move(variableDeclaration)); @@ -2054,7 +2079,7 @@ namespace Nz::ShaderAst if (optimizedExpr->GetType() != NodeType::ConstantValueExpression) { if (!m_context->options.allowPartialSanitization) - throw AstError{ "expected a constant expression" }; + throw ShaderLang::CompilerConstantExpressionRequiredError{ expr.sourceLocation }; return std::nullopt; } @@ -2138,9 +2163,9 @@ namespace Nz::ShaderAst std::unique_ptr SanitizeVisitor::PropagateConstants(T& node) const { AstConstantPropagationVisitor::Options optimizerOptions; - optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue* + optimizerOptions.constantQueryCallback = [&](std::size_t constantId) -> const ConstantValue* { - const ConstantValue* value = m_context->constantValues.TryRetrieve(constantId); + const ConstantValue* value = m_context->constantValues.TryRetrieve(constantId, node.sourceLocation); if (!value && !m_context->options.allowPartialSanitization) throw AstError{ "invalid constant index #" + std::to_string(constantId) }; @@ -2153,17 +2178,17 @@ namespace Nz::ShaderAst void SanitizeVisitor::PreregisterIndices(const Module& module) { - // If AST has been sanitized before and is sanitized again but with differents options that may introduce new variables (for example reduceLoopsToWhile) + // If AST has been sanitized before and is sanitized again but with different options that may introduce new variables (for example reduceLoopsToWhile) // we have to make sure we won't override variable indices. This is done by visiting the AST a first time and preregistering all indices. // TODO: Only do this is the AST has been already sanitized, maybe using a flag stored in the module? AstReflect::Callbacks registerCallbacks; - registerCallbacks.onAliasIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->aliases.PreregisterIndex(index); }; - registerCallbacks.onConstIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->constantValues.PreregisterIndex(index); }; - registerCallbacks.onFunctionIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->functions.PreregisterIndex(index); }; - registerCallbacks.onOptionIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->constantValues.PreregisterIndex(index); }; - registerCallbacks.onStructIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->structs.PreregisterIndex(index); }; - registerCallbacks.onVariableIndex = [this](const std::string& /*name*/, std::size_t index) { m_context->variableTypes.PreregisterIndex(index); }; + registerCallbacks.onAliasIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->aliases.PreregisterIndex(index, sourceLocation); }; + registerCallbacks.onConstIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->constantValues.PreregisterIndex(index, sourceLocation); }; + registerCallbacks.onFunctionIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->functions.PreregisterIndex(index, sourceLocation); }; + registerCallbacks.onOptionIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->constantValues.PreregisterIndex(index, sourceLocation); }; + registerCallbacks.onStructIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->structs.PreregisterIndex(index, sourceLocation); }; + registerCallbacks.onVariableIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->variableTypes.PreregisterIndex(index, sourceLocation); }; AstReflect reflectVisitor; for (const auto& importedModule : module.importedModules) @@ -2174,7 +2199,7 @@ namespace Nz::ShaderAst void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen) { - auto& funcData = m_context->functions.Retrieve(funcIndex); + auto& funcData = m_context->functions.Retrieve(funcIndex, {}); funcData.flags |= flags; for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i)) @@ -2184,10 +2209,10 @@ namespace Nz::ShaderAst void SanitizeVisitor::RegisterBuiltin() { // Primitive types - RegisterType("bool", PrimitiveType::Boolean); - RegisterType("f32", PrimitiveType::Float32); - RegisterType("i32", PrimitiveType::Int32); - RegisterType("u32", PrimitiveType::UInt32); + RegisterType("bool", PrimitiveType::Boolean, std::nullopt, {}); + RegisterType("f32", PrimitiveType::Float32, std::nullopt, {}); + RegisterType("i32", PrimitiveType::Int32, std::nullopt, {}); + RegisterType("u32", PrimitiveType::UInt32, std::nullopt, {}); // Partial types @@ -2228,7 +2253,7 @@ namespace Nz::ShaderAst return arrayType; } - }); + }, std::nullopt, {}); // matX for (std::size_t componentCount = 2; componentCount <= 4; ++componentCount) @@ -2247,7 +2272,7 @@ namespace Nz::ShaderAst componentCount, componentCount, std::get(exprType) }; } - }); + }, std::nullopt, {}); } // vecX @@ -2267,7 +2292,7 @@ namespace Nz::ShaderAst componentCount, std::get(exprType) }; } - }); + }, std::nullopt, {}); } // samplers @@ -2312,7 +2337,7 @@ namespace Nz::ShaderAst sampler.imageType, primitiveType }; } - }); + }, std::nullopt, {}); } // uniform @@ -2331,7 +2356,7 @@ namespace Nz::ShaderAst structType }; } - }); + }, std::nullopt, {}); // Intrinsics RegisterIntrinsic("cross", IntrinsicType::CrossProduct); @@ -2345,17 +2370,17 @@ namespace Nz::ShaderAst RegisterIntrinsic("reflect", IntrinsicType::Reflect); } - std::size_t SanitizeVisitor::RegisterAlias(std::string name, std::optional aliasData, std::optional index) + std::size_t SanitizeVisitor::RegisterAlias(std::string name, std::optional aliasData, std::optional index, const ShaderLang::SourceLocation& sourceLocation) { if (FindIdentifier(name)) - throw AstError{ name + " is already used" }; + throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; std::size_t aliasIndex; if (aliasData) - aliasIndex = m_context->aliases.Register(std::move(*aliasData), index); + aliasIndex = m_context->aliases.Register(std::move(*aliasData), index, sourceLocation); else if (index) { - m_context->aliases.PreregisterIndex(*index); + m_context->aliases.PreregisterIndex(*index, sourceLocation); aliasIndex = *index; } else @@ -2370,17 +2395,17 @@ namespace Nz::ShaderAst return aliasIndex; } - std::size_t SanitizeVisitor::RegisterConstant(std::string name, std::optional value, std::optional index) + std::size_t SanitizeVisitor::RegisterConstant(std::string name, std::optional value, std::optional index, const ShaderLang::SourceLocation& sourceLocation) { if (FindIdentifier(name)) - throw AstError{ name + " is already used" }; + throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; std::size_t constantIndex; if (value) - constantIndex = m_context->constantValues.Register(std::move(*value), index); + constantIndex = m_context->constantValues.Register(std::move(*value), index, sourceLocation); else if (index) { - m_context->constantValues.PreregisterIndex(*index); + m_context->constantValues.PreregisterIndex(*index, sourceLocation); constantIndex = *index; } else @@ -2395,7 +2420,7 @@ namespace Nz::ShaderAst return constantIndex; } - std::size_t SanitizeVisitor::RegisterFunction(std::string name, std::optional funcData, std::optional index) + std::size_t SanitizeVisitor::RegisterFunction(std::string name, std::optional funcData, std::optional index, const ShaderLang::SourceLocation& sourceLocation) { if (auto* identifier = FindIdentifier(name)) { @@ -2407,7 +2432,7 @@ namespace Nz::ShaderAst { if (funcData->node->entryStage.HasValue() && identifier->category == IdentifierCategory::Function) { - auto& otherFunction = m_context->functions.Retrieve(identifier->index); + auto& otherFunction = m_context->functions.Retrieve(identifier->index, sourceLocation); if (funcData->node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue()) duplicate = false; } @@ -2415,21 +2440,21 @@ namespace Nz::ShaderAst else { if (!m_context->options.allowPartialSanitization) - throw AstError{ "internal error" }; + throw ShaderLang::AstInternalError{ sourceLocation, "unexpected missing function data" }; duplicate = false; } if (duplicate) - throw AstError{ name + " is already used" }; + throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; } std::size_t functionIndex; if (funcData) - functionIndex = m_context->functions.Register(std::move(*funcData), index); + functionIndex = m_context->functions.Register(std::move(*funcData), index, sourceLocation); else if (index) { - m_context->functions.PreregisterIndex(*index); + m_context->functions.PreregisterIndex(*index, sourceLocation); functionIndex = *index; } else @@ -2447,9 +2472,9 @@ namespace Nz::ShaderAst std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type) { if (FindIdentifier(name)) - throw AstError{ name + " is already used" }; + throw ShaderLang::CompilerIdentifierAlreadyUsedError{ {}, name }; - std::size_t intrinsicIndex = m_context->intrinsics.Register(std::move(type)); + std::size_t intrinsicIndex = m_context->intrinsics.Register(std::move(type), std::nullopt, {}); m_context->currentEnv->identifiersInScope.push_back({ std::move(name), @@ -2463,9 +2488,9 @@ namespace Nz::ShaderAst std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t index) { if (FindIdentifier(moduleIdentifier)) - throw AstError{ moduleIdentifier + " is already used" }; + throw ShaderLang::CompilerIdentifierAlreadyUsedError{ {}, moduleIdentifier }; - std::size_t moduleIndex = m_context->moduleIndices.Register(index); + std::size_t moduleIndex = m_context->moduleIndices.Register(index, std::nullopt, {}); m_context->currentEnv->identifiersInScope.push_back({ std::move(moduleIdentifier), @@ -2476,17 +2501,17 @@ namespace Nz::ShaderAst return moduleIndex; } - std::size_t SanitizeVisitor::RegisterStruct(std::string name, std::optional description, std::optional index) + std::size_t SanitizeVisitor::RegisterStruct(std::string name, std::optional description, std::optional index, const ShaderLang::SourceLocation& sourceLocation) { if (FindIdentifier(name)) - throw AstError{ name + " is already used" }; + throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; std::size_t structIndex; if (description) - structIndex = m_context->structs.Register(*description, index); + structIndex = m_context->structs.Register(*description, index, sourceLocation); else if (index) { - m_context->structs.PreregisterIndex(*index); + m_context->structs.PreregisterIndex(*index, sourceLocation); structIndex = *index; } else @@ -2501,17 +2526,17 @@ namespace Nz::ShaderAst return structIndex; } - std::size_t SanitizeVisitor::RegisterType(std::string name, std::optional expressionType, std::optional index) + std::size_t SanitizeVisitor::RegisterType(std::string name, std::optional expressionType, std::optional index, const ShaderLang::SourceLocation& sourceLocation) { if (FindIdentifier(name)) - throw AstError{ name + " is already used" }; + throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; std::size_t typeIndex; if (expressionType) - typeIndex = m_context->types.Register(std::move(*expressionType), index); + typeIndex = m_context->types.Register(std::move(*expressionType), index, sourceLocation); else if (index) { - m_context->types.PreregisterIndex(*index); + m_context->types.PreregisterIndex(*index, sourceLocation); typeIndex = *index; } else @@ -2526,17 +2551,17 @@ namespace Nz::ShaderAst return typeIndex; } - std::size_t SanitizeVisitor::RegisterType(std::string name, std::optional partialType, std::optional index) + std::size_t SanitizeVisitor::RegisterType(std::string name, std::optional partialType, std::optional index, const ShaderLang::SourceLocation& sourceLocation) { if (FindIdentifier(name)) - throw AstError{ name + " is already used" }; + throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; std::size_t typeIndex; if (partialType) - typeIndex = m_context->types.Register(std::move(*partialType), index); + typeIndex = m_context->types.Register(std::move(*partialType), index, sourceLocation); else if (index) { - m_context->types.PreregisterIndex(*index); + m_context->types.PreregisterIndex(*index, sourceLocation); typeIndex = *index; } else @@ -2560,21 +2585,21 @@ namespace Nz::ShaderAst }); } - std::size_t SanitizeVisitor::RegisterVariable(std::string name, std::optional type, std::optional index) + std::size_t SanitizeVisitor::RegisterVariable(std::string name, std::optional type, std::optional index, const ShaderLang::SourceLocation& sourceLocation) { if (auto* identifier = FindIdentifier(name)) { // Allow variable shadowing if (identifier->category != IdentifierCategory::Variable) - throw AstError{ name + " is already used" }; + throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; } std::size_t varIndex; if (type) - varIndex = m_context->variableTypes.Register(std::move(*type), index); + varIndex = m_context->variableTypes.Register(std::move(*type), index, sourceLocation); else if (index) { - m_context->variableTypes.PreregisterIndex(*index); + m_context->variableTypes.PreregisterIndex(*index, sourceLocation); varIndex = *index; } else @@ -2589,10 +2614,10 @@ namespace Nz::ShaderAst return varIndex; } - auto SanitizeVisitor::ResolveAliasIdentifier(const IdentifierData* identifier) const -> const IdentifierData* + auto SanitizeVisitor::ResolveAliasIdentifier(const IdentifierData* identifier, const ShaderLang::SourceLocation& sourceLocation) const -> const IdentifierData* { while (identifier->category == IdentifierCategory::Alias) - identifier = &m_context->aliases.Retrieve(identifier->index); + identifier = &m_context->aliases.Retrieve(identifier->index, sourceLocation); return identifier; } @@ -2606,7 +2631,7 @@ namespace Nz::ShaderAst for (auto& parameter : pendingFunc.cloneNode->parameters) { - parameter.varIndex = RegisterVariable(parameter.name, parameter.type.GetResultingValue(), parameter.varIndex); + parameter.varIndex = RegisterVariable(parameter.name, parameter.type.GetResultingValue(), parameter.varIndex, parameter.sourceLocation); SanitizeIdentifier(parameter.name); } @@ -2621,7 +2646,7 @@ namespace Nz::ShaderAst pendingFunc.cloneNode->statements.reserve(pendingFunc.node->statements.size()); for (auto& statement : pendingFunc.node->statements) - pendingFunc.cloneNode->statements.push_back(CloneStatement(MandatoryStatement(statement))); + pendingFunc.cloneNode->statements.push_back(CloneStatement(MandatoryStatement(statement, pendingFunc.cloneNode->sourceLocation))); m_context->currentStatementList = previousList; m_context->currentFunction = nullptr; @@ -2629,7 +2654,7 @@ namespace Nz::ShaderAst std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex; for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i)) { - auto& targetFunc = m_context->functions.Retrieve(i); + auto& targetFunc = m_context->functions.Retrieve(i, pendingFunc.cloneNode->sourceLocation); targetFunc.calledByFunctions.UnboundedSet(funcIndex); } @@ -2704,7 +2729,7 @@ namespace Nz::ShaderAst return uniformType.containedType.structIndex; } - ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias) + ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias, const ShaderLang::SourceLocation& sourceLocation) { if (!IsTypeExpression(exprType)) { @@ -2716,20 +2741,20 @@ namespace Nz::ShaderAst std::size_t typeIndex = std::get(exprType).typeIndex; - const auto& type = m_context->types.Retrieve(typeIndex); + const auto& type = m_context->types.Retrieve(typeIndex, sourceLocation); if (std::holds_alternative(type)) throw AstError{ "full type expected" }; return std::get(type); } - std::optional SanitizeVisitor::ResolveTypeExpr(const ExpressionValue& exprTypeValue, bool resolveAlias) + std::optional SanitizeVisitor::ResolveTypeExpr(const ExpressionValue& exprTypeValue, bool resolveAlias, const ShaderLang::SourceLocation& sourceLocation) { if (!exprTypeValue.HasValue()) return NoType{}; if (exprTypeValue.IsResultingValue()) - return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias); + return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias, sourceLocation); assert(exprTypeValue.IsExpression()); ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression()); @@ -2740,7 +2765,7 @@ namespace Nz::ShaderAst //if (!IsTypeType(exprType)) // throw AstError{ "type expected" }; - return ResolveType(*exprType, resolveAlias); + return ResolveType(*exprType, resolveAlias, sourceLocation); } void SanitizeVisitor::SanitizeIdentifier(std::string& identifier) @@ -2786,33 +2811,17 @@ namespace Nz::ShaderAst return output; } - auto SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const -> ValidationResult + auto SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right, const ShaderLang::SourceLocation& sourceLocation) -> ValidationResult { const ExpressionType* leftType = GetExpressionType(*left); const ExpressionType* rightType = GetExpressionType(*right); if (!leftType || !rightType) return ValidationResult::Unresolved; - TypeMustMatch(*leftType, *rightType); + TypeMustMatch(*leftType, *rightType, sourceLocation); return ValidationResult::Validated; } - void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const - { - if (ResolveAlias(left) != ResolveAlias(right)) - throw AstError{ "Left expression type must match right expression type" }; - } - - StatementPtr SanitizeVisitor::Unscope(StatementPtr node) - { - assert(node); - - if (node->GetType() == NodeType::ScopedStatement) - return std::move(static_cast(*node).statement); - else - return node; - } - auto SanitizeVisitor::Validate(DeclareAliasStatement& node) -> ValidationResult { if (node.name.empty()) @@ -2843,27 +2852,27 @@ namespace Nz::ShaderAst else throw AstError{ "for now, only aliases, functions and structs can be aliased" }; - node.aliasIndex = RegisterAlias(node.name, targetIdentifier, node.aliasIndex); + node.aliasIndex = RegisterAlias(node.name, targetIdentifier, node.aliasIndex, node.sourceLocation); return ValidationResult::Validated; } auto SanitizeVisitor::Validate(WhileStatement& node) -> ValidationResult { - const ExpressionType* conditionType = GetExpressionType(MandatoryExpr(node.condition)); - MandatoryStatement(node.body); + const ExpressionType* conditionType = GetExpressionType(MandatoryExpr(node.condition, node.sourceLocation)); + MandatoryStatement(node.body, node.sourceLocation); if (!conditionType) return ValidationResult::Unresolved; if (ResolveAlias(*conditionType) != ExpressionType{ PrimitiveType::Boolean }) - throw AstError{ "expected a boolean value" }; + throw ShaderLang::CompilerConditionExpectedBoolError{ node.condition->sourceLocation }; return ValidationResult::Validated; } auto SanitizeVisitor::Validate(AccessIndexExpression& node) -> ValidationResult { - const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expr)); + const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expr, node.sourceLocation)); if (!exprType) return ValidationResult::Unresolved; @@ -2872,7 +2881,7 @@ namespace Nz::ShaderAst if (IsTypeExpression(resolvedExprType)) { std::size_t typeIndex = std::get(resolvedExprType).typeIndex; - const auto& type = m_context->types.Retrieve(typeIndex); + const auto& type = m_context->types.Retrieve(typeIndex, node.sourceLocation); if (!std::holds_alternative(type)) throw std::runtime_error("only partial types can be specialized"); @@ -2905,14 +2914,14 @@ namespace Nz::ShaderAst if (!indexExprType) return ValidationResult::Unresolved; - ExpressionType resolvedType = ResolveType(*indexExprType, true); + ExpressionType resolvedType = ResolveType(*indexExprType, true, node.sourceLocation); switch (partialType.parameters[i]) { case TypeParameterCategory::PrimitiveType: { if (!IsPrimitiveType(resolvedType)) - throw std::runtime_error("expected a primitive type"); + throw ShaderLang::CompilerPartialTypeExpectError{ indexExpr->sourceLocation, "primitive", SafeCast(i) }; break; } @@ -2920,7 +2929,7 @@ namespace Nz::ShaderAst case TypeParameterCategory::StructType: { if (!IsStructType(resolvedType)) - throw std::runtime_error("expected a struct type"); + throw ShaderLang::CompilerPartialTypeExpectError{ indexExpr->sourceLocation, "struct", SafeCast(i) }; break; } @@ -2941,7 +2950,7 @@ namespace Nz::ShaderAst else { if (node.indices.size() != 1) - throw AstError{ "AccessIndexExpression must have at one index" }; + throw ShaderLang::AstNoIndexError{ node.sourceLocation }; for (const auto& indexExpr : node.indices) { @@ -2972,9 +2981,9 @@ namespace Nz::ShaderAst Int32 index = std::get(constantExpr.value); std::size_t structIndex = ResolveStruct(resolvedExprType); - const StructDescription* s = m_context->structs.Retrieve(structIndex); + const StructDescription* s = m_context->structs.Retrieve(structIndex, indexExpr->sourceLocation); - std::optional resolvedExprTypeOpt = ResolveTypeExpr(s->members[index].type, true); + std::optional resolvedExprTypeOpt = ResolveTypeExpr(s->members[index].type, true, indexExpr->sourceLocation); if (!resolvedExprTypeOpt.has_value()) return ValidationResult::Unresolved; @@ -2996,7 +3005,7 @@ namespace Nz::ShaderAst resolvedExprType = swizzledVec.type; } else - throw AstError{ "unexpected type (only struct, vectors and matrices can be indexed)" }; //< TODO: Add support for arrays + throw AstError{ "unexpected type (only struct, vectors and matrices can be indexed)" }; } node.cachedExpressionType = std::move(resolvedExprType); @@ -3007,14 +3016,11 @@ namespace Nz::ShaderAst auto SanitizeVisitor::Validate(AssignExpression& node) -> ValidationResult { - MandatoryExpr(node.left); - MandatoryExpr(node.right); - - const ExpressionType* leftExprType = GetExpressionType(MandatoryExpr(node.left)); + const ExpressionType* leftExprType = GetExpressionType(MandatoryExpr(node.left, node.sourceLocation)); if (!leftExprType) return ValidationResult::Unresolved; - const ExpressionType* rightExprType = GetExpressionType(MandatoryExpr(node.right)); + const ExpressionType* rightExprType = GetExpressionType(MandatoryExpr(node.right, node.sourceLocation)); if (!rightExprType) return ValidationResult::Unresolved; @@ -3025,7 +3031,7 @@ namespace Nz::ShaderAst switch (node.op) { case AssignType::Simple: - if (TypeMustMatch(node.left, node.right) == ValidationResult::Unresolved) + if (TypeMustMatch(node.left, node.right, node.sourceLocation) == ValidationResult::Unresolved) return ValidationResult::Unresolved; break; @@ -3040,8 +3046,8 @@ namespace Nz::ShaderAst if (binaryType) { - ExpressionType expressionType = ValidateBinaryOp(*binaryType, ResolveAlias(*leftExprType), ResolveAlias(*rightExprType)); - TypeMustMatch(*leftExprType, expressionType); + ExpressionType expressionType = ValidateBinaryOp(*binaryType, ResolveAlias(*leftExprType), ResolveAlias(*rightExprType), node.sourceLocation); + TypeMustMatch(*leftExprType, expressionType, node.sourceLocation); if (m_context->options.removeCompoundAssignments) { @@ -3057,18 +3063,15 @@ namespace Nz::ShaderAst auto SanitizeVisitor::Validate(BinaryExpression& node) -> ValidationResult { - MandatoryExpr(node.left); - MandatoryExpr(node.right); - - const ExpressionType* leftExprType = GetExpressionType(MandatoryExpr(node.left)); + const ExpressionType* leftExprType = GetExpressionType(MandatoryExpr(node.left, node.sourceLocation)); if (!leftExprType) return ValidationResult::Unresolved; - const ExpressionType* rightExprType = GetExpressionType(MandatoryExpr(node.right)); + const ExpressionType* rightExprType = GetExpressionType(MandatoryExpr(node.right, node.sourceLocation)); if (!rightExprType) return ValidationResult::Unresolved; - node.cachedExpressionType = ValidateBinaryOp(node.op, ResolveAlias(*leftExprType), ResolveAlias(*rightExprType)); + node.cachedExpressionType = ValidateBinaryOp(node.op, ResolveAlias(*leftExprType), ResolveAlias(*rightExprType), node.sourceLocation); return ValidationResult::Validated; } @@ -3081,7 +3084,7 @@ namespace Nz::ShaderAst { const auto& alias = static_cast(*node.targetFunction); - const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId)); + const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId, node.sourceLocation), node.sourceLocation); if (targetIdentifier->category != IdentifierCategory::Function) throw AstError{ "expected function expression" }; @@ -3090,7 +3093,7 @@ namespace Nz::ShaderAst else throw AstError{ "expected function expression" }; - auto& funcData = m_context->functions.Retrieve(targetFuncIndex); + auto& funcData = m_context->functions.Retrieve(targetFuncIndex, node.sourceLocation); const DeclareFunctionStatement* referenceDeclaration = funcData.node; @@ -3116,21 +3119,19 @@ namespace Nz::ShaderAst auto SanitizeVisitor::Validate(CastExpression& node) -> ValidationResult { - std::optional targetTypeOpt = ResolveTypeExpr(node.targetType); + std::optional targetTypeOpt = ResolveTypeExpr(node.targetType, false, node.sourceLocation); if (!targetTypeOpt) return ValidationResult::Unresolved; const ExpressionType& targetType = ResolveAlias(*targetTypeOpt); - const auto& firstExprPtr = node.expressions.front(); - if (!firstExprPtr) - throw AstError{ "expected at least one expression" }; + auto& firstExprPtr = MandatoryExpr(node.expressions.front(), node.sourceLocation); if (IsMatrixType(targetType)) { const MatrixType& targetMatrixType = std::get(targetType); - const ExpressionType* firstExprType = GetExpressionType(*firstExprPtr); + const ExpressionType* firstExprType = GetExpressionType(firstExprPtr); if (!firstExprType) return ValidationResult::Unresolved; @@ -3191,13 +3192,13 @@ namespace Nz::ShaderAst const ExpressionType& resolvedExprType = ResolveAlias(*exprType); if (!IsPrimitiveType(resolvedExprType) && !IsVectorType(resolvedExprType)) - throw AstError{ "incompatible type" }; + throw ShaderLang::CompilerCastIncompatibleTypesError{ exprPtr->sourceLocation }; componentCount += GetComponentCount(resolvedExprType); } if (componentCount != requiredComponents) - throw AstError{ "component count doesn't match required component count" }; + throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation }; } node.cachedExpressionType = targetType; @@ -3212,7 +3213,7 @@ namespace Nz::ShaderAst if (!node.varType.HasValue()) { if (!node.initialExpression) - throw AstError{ "variable must either have a type or an initial value" }; + throw ShaderLang::CompilerVarDeclarationMissingTypeAndValueError{ node.sourceLocation }; const ExpressionType* initialExprType = GetExpressionType(*node.initialExpression); if (!initialExprType) @@ -3225,7 +3226,7 @@ namespace Nz::ShaderAst } else { - std::optional varType = ResolveTypeExpr(node.varType); + std::optional varType = ResolveTypeExpr(node.varType, false, node.sourceLocation); if (!varType) { RegisterUnresolved(node.varName); @@ -3242,11 +3243,11 @@ namespace Nz::ShaderAst return ValidationResult::Unresolved; } - TypeMustMatch(resolvedType, *initialExprType); + TypeMustMatch(resolvedType, *initialExprType, node.sourceLocation); } } - node.varIndex = RegisterVariable(node.varName, resolvedType, node.varIndex); + node.varIndex = RegisterVariable(node.varName, resolvedType, node.varIndex, node.sourceLocation); node.varType = std::move(resolvedType); if (m_context->options.makeVariableNameUnique) @@ -3281,194 +3282,153 @@ namespace Nz::ShaderAst auto SanitizeVisitor::Validate(IntrinsicExpression& node) -> ValidationResult { - // Parameter validation + auto IsFloatingPointVector = [](const ExpressionType& type) + { + return type == ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }; + }; + + auto CheckNotBoolean = [](Expression& expression, const ExpressionType& type) + { + if ((IsPrimitiveType(type) && std::get(type) == PrimitiveType::Boolean) || + (IsVectorType(type) && std::get(type).type == PrimitiveType::Boolean)) + throw ShaderLang::CompilerIntrinsicUnexpectedBooleanError{ expression.sourceLocation }; + }; + + auto CheckFloatingPoint = [](Expression& expression, const ExpressionType& type) + { + if ((IsPrimitiveType(type) && std::get(type) != PrimitiveType::Float32) || + (IsVectorType(type) && std::get(type).type != PrimitiveType::Float32)) + throw ShaderLang::CompilerIntrinsicExpectedFloatError{ expression.sourceLocation }; + }; + + auto SetReturnTypeToFirstParameterType = [&] + { + node.cachedExpressionType = GetExpressionTypeSecure(*node.parameters.front()); + return ValidationResult::Validated; + }; + + auto SetReturnTypeToFirstParameterInnerType = [&] + { + node.cachedExpressionType = std::get(GetExpressionTypeSecure(*node.parameters.front())).type; + return ValidationResult::Validated; + }; + + auto IsUnresolved = [](ValidationResult result) { return result == ValidationResult::Unresolved; }; + + // Parameter validation and return type attribution switch (node.intrinsic) { case IntrinsicType::CrossProduct: - case IntrinsicType::DotProduct: - case IntrinsicType::Max: - case IntrinsicType::Min: - case IntrinsicType::Pow: - case IntrinsicType::Reflect: - { - if (node.parameters.size() != 2) - throw AstError { "Expected two parameters" }; - - for (auto& param : node.parameters) - MandatoryExpr(param); - - const ExpressionType* firstParameterType = GetExpressionType(*node.parameters.front()); - if (!firstParameterType) + if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) + || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) + || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector))) return ValidationResult::Unresolved; - for (std::size_t i = 1; i < node.parameters.size(); ++i) - { - const ExpressionType* parameterType = GetExpressionType(*node.parameters[i]); - if (!parameterType) - return ValidationResult::Unresolved; + return SetReturnTypeToFirstParameterType(); - if (ResolveAlias(*firstParameterType) != ResolveAlias(*parameterType)) - throw AstError{ "All type must match" }; + case IntrinsicType::DotProduct: + if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) + || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) + || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector))) + return ValidationResult::Unresolved; + + return SetReturnTypeToFirstParameterInnerType(); + + case IntrinsicType::Exp: + if (IsUnresolved(ValidateIntrinsicParamCount<1>(node)) + || IsUnresolved(ValidateIntrinsicParameter<0>(node, CheckFloatingPoint))) + return ValidationResult::Unresolved; + + return SetReturnTypeToFirstParameterType(); + + case IntrinsicType::Length: + if (IsUnresolved(ValidateIntrinsicParamCount<1>(node)) + || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector))) + return ValidationResult::Unresolved; + + return SetReturnTypeToFirstParameterInnerType(); + + case IntrinsicType::Max: + case IntrinsicType::Min: + if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) + || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) + || IsUnresolved(ValidateIntrinsicParameter<0>(node, CheckNotBoolean))) + return ValidationResult::Unresolved; + + return SetReturnTypeToFirstParameterType(); + + case IntrinsicType::Normalize: + if (IsUnresolved(ValidateIntrinsicParamCount<1>(node)) + || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector))) + return ValidationResult::Unresolved; + + return SetReturnTypeToFirstParameterType(); + + case IntrinsicType::Pow: + if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) + || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) + || IsUnresolved(ValidateIntrinsicParameter<0>(node, CheckFloatingPoint))) + return ValidationResult::Unresolved; + + return SetReturnTypeToFirstParameterType(); + + case IntrinsicType::Reflect: + if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) + || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) + || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector))) + return ValidationResult::Unresolved; + + return SetReturnTypeToFirstParameterType(); + + case IntrinsicType::SampleTexture: + { + if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) + || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsSamplerType))) + return ValidationResult::Unresolved; + + // Special check: vector dimensions must match sample type + const SamplerType& samplerType = std::get(ResolveAlias(GetExpressionTypeSecure(*node.parameters[0]))); + std::size_t requiredComponentCount = 0; + switch (samplerType.dim) + { + case ImageType::E1D: + requiredComponentCount = 1; + break; + + case ImageType::E1D_Array: + case ImageType::E2D: + requiredComponentCount = 2; + break; + + case ImageType::E2D_Array: + case ImageType::E3D: + case ImageType::Cubemap: + requiredComponentCount = 3; + break; } - break; - } + if (requiredComponentCount == 0) + throw ShaderLang::AstInternalError{ node.parameters[0]->sourceLocation, "unhandled sampler dimensions" }; - case IntrinsicType::Exp: - { - if (node.parameters.size() != 1) - throw AstError{ "Expected only one parameters" }; + auto IsRightType = [=](const ExpressionType& type) + { + return type == ExpressionType{ VectorType{ requiredComponentCount, PrimitiveType::Float32 } }; + }; - MandatoryExpr(node.parameters.front()); - break; - } - - case IntrinsicType::Length: - case IntrinsicType::Normalize: - { - if (node.parameters.size() != 1) - throw AstError{ "Expected only one parameters" }; - - const ExpressionType* type = GetExpressionType(MandatoryExpr(node.parameters.front())); - if (!type) + if (IsUnresolved(ValidateIntrinsicParameterType<1>(node, IsRightType))) return ValidationResult::Unresolved; - const ExpressionType& resolvedType = ResolveAlias(*type); - if (!IsVectorType(resolvedType)) - throw AstError{ "Expected a vector" }; - - break; - } - - case IntrinsicType::SampleTexture: - { - if (node.parameters.size() != 2) - throw AstError{ "Expected two parameters" }; - - for (auto& param : node.parameters) - MandatoryExpr(param); - - const ExpressionType* firstParameterType = GetExpressionType(*node.parameters[0]); - if (!firstParameterType) - return ValidationResult::Unresolved; - - if (!IsSamplerType(*firstParameterType)) - throw AstError{ "First parameter must be a sampler" }; - - const ExpressionType* secondParameterType = GetExpressionType(*node.parameters[1]); - if (!secondParameterType) - return ValidationResult::Unresolved; - - if (!IsVectorType(*secondParameterType)) - throw AstError{ "Second parameter must be a vector" }; - - break; + node.cachedExpressionType = VectorType{ 4, samplerType.sampledType }; + return ValidationResult::Validated; } } - // Return type attribution - switch (node.intrinsic) - { - case IntrinsicType::CrossProduct: - { - const ExpressionType* type = GetExpressionType(*node.parameters.front()); - if (!type) - return ValidationResult::Unresolved; - - if (ResolveAlias(*type) != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) - throw AstError{ "CrossProduct only works with vec3[f32] expressions" }; - - node.cachedExpressionType = *type; - break; - } - - case IntrinsicType::DotProduct: - case IntrinsicType::Length: - { - const ExpressionType* type = GetExpressionType(*node.parameters.front()); - if (!type) - return ValidationResult::Unresolved; - - const ExpressionType& resolvedType = ResolveAlias(*type); - if (!IsVectorType(resolvedType)) - throw AstError{ "DotProduct expects vector types" }; //< FIXME - - node.cachedExpressionType = std::get(resolvedType).type; - break; - } - - case IntrinsicType::Normalize: - case IntrinsicType::Reflect: - { - const ExpressionType* type = GetExpressionType(*node.parameters.front()); - if (!type) - return ValidationResult::Unresolved; - - const ExpressionType& resolvedType = ResolveAlias(*type); - if (!IsVectorType(resolvedType)) - throw AstError{ "DotProduct expects vector types" }; //< FIXME - - node.cachedExpressionType = *type; - break; - } - - case IntrinsicType::Max: - case IntrinsicType::Min: - { - const ExpressionType* type = GetExpressionType(*node.parameters.front()); - if (!type) - return ValidationResult::Unresolved; - - const ExpressionType& resolvedType = ResolveAlias(*type); - if (!IsPrimitiveType(resolvedType) && !IsVectorType(resolvedType)) - throw AstError{ "max and min only work with primitive and vector types" }; - - if ((IsPrimitiveType(resolvedType) && std::get(resolvedType) == PrimitiveType::Boolean) || - (IsVectorType(resolvedType) && std::get(resolvedType).type == PrimitiveType::Boolean)) - throw AstError{ "max and min do not work with booleans" }; - - node.cachedExpressionType = *type; - break; - } - - case IntrinsicType::Exp: - case IntrinsicType::Pow: - { - const ExpressionType* type = GetExpressionType(*node.parameters.front()); - if (!type) - return ValidationResult::Unresolved; - - const ExpressionType& resolvedType = ResolveAlias(*type); - if (!IsPrimitiveType(resolvedType) && !IsVectorType(resolvedType)) - throw AstError{ "pow only works with primitive and vector types" }; - - if ((IsPrimitiveType(resolvedType) && std::get(resolvedType) != PrimitiveType::Float32) || - (IsVectorType(resolvedType) && std::get(resolvedType).type != PrimitiveType::Float32)) - throw AstError{ "pow only works with floating-point primitive or vectors" }; - - node.cachedExpressionType = *type; - break; - } - - case IntrinsicType::SampleTexture: - { - const ExpressionType* type = GetExpressionType(*node.parameters.front()); - if (!type) - return ValidationResult::Unresolved; - - const ExpressionType& resolvedType = ResolveAlias(*type); - node.cachedExpressionType = VectorType{ 4, std::get(resolvedType).sampledType }; - break; - } - } - - return ValidationResult::Validated; + throw ShaderLang::AstInternalError{ node.sourceLocation, "unhandled intrinsic" }; } auto SanitizeVisitor::Validate(SwizzleExpression& node) -> ValidationResult { - MandatoryExpr(node.expression); - const ExpressionType* exprType = GetExpressionType(*node.expression); + const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expression, node.sourceLocation)); if (!exprType) return ValidationResult::Unresolved; @@ -3482,7 +3442,7 @@ namespace Nz::ShaderAst if (IsPrimitiveType(resolvedExprType)) { if (m_context->options.removeScalarSwizzling) - throw AstError{ "internal error" }; //< scalar swizzling should have been removed by then + throw ShaderLang::AstInternalError{ node.sourceLocation, "scalar swizzling should have been removed before validating" }; baseType = std::get(resolvedExprType); componentCount = 1; @@ -3495,12 +3455,12 @@ namespace Nz::ShaderAst } if (node.componentCount > 4) - throw AstError{ "cannot swizzle more than four elements" }; + throw ShaderLang::CompilerInvalidSwizzleError{ node.sourceLocation }; for (std::size_t i = 0; i < node.componentCount; ++i) { if (node.components[i] >= componentCount) - throw AstError{ "invalid swizzle" }; + throw ShaderLang::CompilerInvalidSwizzleError{ node.sourceLocation }; } if (node.componentCount > 1) @@ -3518,7 +3478,7 @@ namespace Nz::ShaderAst auto SanitizeVisitor::Validate(UnaryExpression& node) -> ValidationResult { - const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expression)); + const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expression, node.sourceLocation)); if (!exprType) return ValidationResult::Unresolved; @@ -3529,7 +3489,7 @@ namespace Nz::ShaderAst case UnaryType::LogicalNot: { if (resolvedExprType != ExpressionType(PrimitiveType::Boolean)) - throw AstError{ "logical not is only supported on booleans" }; + throw ShaderLang::CompilerUnaryUnsupportedError{ node.sourceLocation }; break; } @@ -3543,10 +3503,10 @@ namespace Nz::ShaderAst else if (IsVectorType(resolvedExprType)) basicType = std::get(resolvedExprType).type; else - throw AstError{ "plus and minus unary expressions are only supported on primitive/vectors types" }; + throw ShaderLang::CompilerUnaryUnsupportedError{ node.sourceLocation }; if (basicType != PrimitiveType::Float32 && basicType != PrimitiveType::Int32 && basicType != PrimitiveType::UInt32) - throw AstError{ "plus and minus unary expressions are only supported on floating points and integers types" }; + throw ShaderLang::CompilerUnaryUnsupportedError{ node.sourceLocation }; break; } @@ -3558,17 +3518,17 @@ namespace Nz::ShaderAst auto SanitizeVisitor::Validate(VariableValueExpression& node) -> ValidationResult { - node.cachedExpressionType = m_context->variableTypes.Retrieve(node.variableId); + node.cachedExpressionType = m_context->variableTypes.Retrieve(node.variableId, node.sourceLocation); return ValidationResult::Validated; } - ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType) + ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType, const ShaderLang::SourceLocation& sourceLocation) { if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType)) - throw AstError{ "left expression type does not support binary operation" }; + throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left" }; if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType)) - throw AstError{ "right expression type does not support binary operation" }; + throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "right" }; if (IsPrimitiveType(leftExprType)) { @@ -3580,19 +3540,19 @@ namespace Nz::ShaderAst case BinaryType::CompLe: case BinaryType::CompLt: if (leftType == PrimitiveType::Boolean) - throw AstError{ "this operation is not supported for booleans" }; + throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left" }; [[fallthrough]]; case BinaryType::CompEq: case BinaryType::CompNe: { - TypeMustMatch(leftExprType, rightExprType); + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return PrimitiveType::Boolean; } case BinaryType::Add: case BinaryType::Subtract: - TypeMustMatch(leftExprType, rightExprType); + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; case BinaryType::Multiply: @@ -3606,30 +3566,30 @@ namespace Nz::ShaderAst { if (IsMatrixType(rightExprType)) { - TypeMustMatch(leftType, std::get(rightExprType).type); + TypeMustMatch(leftType, std::get(rightExprType).type, sourceLocation); return rightExprType; } else if (IsPrimitiveType(rightExprType)) { - TypeMustMatch(leftType, rightExprType); + TypeMustMatch(leftType, rightExprType, sourceLocation); return leftExprType; } else if (IsVectorType(rightExprType)) { - TypeMustMatch(leftType, std::get(rightExprType).type); + TypeMustMatch(leftType, std::get(rightExprType).type, sourceLocation); return rightExprType; } else - throw AstError{ "incompatible types" }; + throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation }; break; } case PrimitiveType::Boolean: - throw AstError{ "this operation is not supported for booleans" }; + throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left" }; default: - throw AstError{ "incompatible types" }; + throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation }; } } @@ -3637,9 +3597,9 @@ namespace Nz::ShaderAst case BinaryType::LogicalOr: { if (leftType != PrimitiveType::Boolean) - throw AstError{ "logical and/or are only supported on booleans" }; + throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left" }; - TypeMustMatch(leftExprType, rightExprType); + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return PrimitiveType::Boolean; } } @@ -3655,12 +3615,12 @@ namespace Nz::ShaderAst case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: - TypeMustMatch(leftExprType, rightExprType); + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return PrimitiveType::Boolean; case BinaryType::Add: case BinaryType::Subtract: - TypeMustMatch(leftExprType, rightExprType); + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; case BinaryType::Multiply: @@ -3668,31 +3628,31 @@ namespace Nz::ShaderAst { if (IsMatrixType(rightExprType)) { - TypeMustMatch(leftExprType, rightExprType); + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; //< FIXME } else if (IsPrimitiveType(rightExprType)) { - TypeMustMatch(leftType.type, rightExprType); + TypeMustMatch(leftType.type, rightExprType, sourceLocation); return leftExprType; } else if (IsVectorType(rightExprType)) { const VectorType& rightType = std::get(rightExprType); - TypeMustMatch(leftType.type, rightType.type); + TypeMustMatch(leftType.type, rightType.type, sourceLocation); if (leftType.columnCount != rightType.componentCount) - throw AstError{ "incompatible types" }; + throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation }; return rightExprType; } else - throw AstError{ "incompatible types" }; + throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation }; } case BinaryType::LogicalAnd: case BinaryType::LogicalOr: - throw AstError{ "logical and/or are only supported on booleans" }; + throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left" }; } } else if (IsVectorType(leftExprType)) @@ -3706,12 +3666,12 @@ namespace Nz::ShaderAst case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: - TypeMustMatch(leftExprType, rightExprType); + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return PrimitiveType::Boolean; case BinaryType::Add: case BinaryType::Subtract: - TypeMustMatch(leftExprType, rightExprType); + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; case BinaryType::Multiply: @@ -3719,26 +3679,150 @@ namespace Nz::ShaderAst { if (IsPrimitiveType(rightExprType)) { - TypeMustMatch(leftType.type, rightExprType); + TypeMustMatch(leftType.type, rightExprType, sourceLocation); return leftExprType; } else if (IsVectorType(rightExprType)) { - TypeMustMatch(leftType, rightExprType); + TypeMustMatch(leftType, rightExprType, sourceLocation); return rightExprType; } else - throw AstError{ "incompatible types" }; + throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation }; break; } case BinaryType::LogicalAnd: case BinaryType::LogicalOr: - throw AstError{ "logical and/or are only supported on booleans" }; + throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left" }; } } - throw AstError{ "internal error: unchecked operation" }; + throw ShaderLang::AstInternalError{ sourceLocation, "unchecked operation" }; + } + + template + auto SanitizeVisitor::ValidateIntrinsicParamCount(IntrinsicExpression& node) -> ValidationResult + { + if (node.parameters.size() != N) + throw ShaderLang::CompilerIntrinsicExpectedParameterCountError{ node.sourceLocation, SafeCast(N) }; + + for (auto& param : node.parameters) + MandatoryExpr(param, node.sourceLocation); + + return ValidationResult::Validated; + } + + auto SanitizeVisitor::ValidateIntrinsicParamMatchingType(IntrinsicExpression& node) -> ValidationResult + { + const ExpressionType* firstParameterType = GetExpressionType(*node.parameters.front()); + if (!firstParameterType) + return ValidationResult::Unresolved; + + for (std::size_t i = 1; i < node.parameters.size(); ++i) + { + const ExpressionType* parameterType = GetExpressionType(*node.parameters[i]); + if (!parameterType) + return ValidationResult::Unresolved; + + if (ResolveAlias(*firstParameterType) != ResolveAlias(*parameterType)) + throw ShaderLang::CompilerIntrinsicUnmatchingParameterTypeError{ node.parameters[i]->sourceLocation }; + } + + return ValidationResult::Validated; + } + + template + auto SanitizeVisitor::ValidateIntrinsicParameter(IntrinsicExpression& node, F&& func) -> ValidationResult + { + assert(node.parameters.size() > N); + auto& parameter = MandatoryExpr(node.parameters[N], node.sourceLocation); + const ExpressionType* type = GetExpressionType(parameter); + if (!type) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedType = ResolveAlias(*type); + func(parameter, resolvedType); + + return ValidationResult::Validated; + } + + template + auto SanitizeVisitor::ValidateIntrinsicParameterType(IntrinsicExpression& node, F&& func) -> ValidationResult + { + assert(node.parameters.size() > N); + auto& parameter = MandatoryExpr(node.parameters[N], node.sourceLocation); + + const ExpressionType* type = GetExpressionType(parameter); + if (!type) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedType = ResolveAlias(*type); + if (!func(resolvedType)) + throw ShaderLang::CompilerIntrinsicExpectedTypeError{ parameter.sourceLocation, SafeCast(N) }; + + return ValidationResult::Validated; + } + + Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node, const ShaderLang::SourceLocation& sourceLocation) + { + if (!node) + throw ShaderLang::AstMissingExpressionError{ sourceLocation }; + + return *node; + } + + Statement& SanitizeVisitor::MandatoryStatement(const StatementPtr& node, const ShaderLang::SourceLocation& sourceLocation) + { + if (!node) + throw ShaderLang::AstMissingStatementError{ sourceLocation }; + + return *node; + } + + void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right, const ShaderLang::SourceLocation& sourceLocation) + { + if (ResolveAlias(left) != ResolveAlias(right)) + throw ShaderLang::CompilerUnmatchingTypesError{ sourceLocation }; + } + + StatementPtr SanitizeVisitor::Unscope(StatementPtr node) + { + assert(node); + + if (node->GetType() == NodeType::ScopedStatement) + return std::move(static_cast(*node).statement); + else + return node; + } + + UInt32 SanitizeVisitor::ToSwizzleIndex(char c, const ShaderLang::SourceLocation& sourceLocation) + { + switch (c) + { + case 'r': + case 'x': + case 's': + return 0u; + + case 'g': + case 'y': + case 't': + return 1u; + + case 'b': + case 'z': + case 'p': + return 2u; + + case 'a': + case 'w': + case 'q': + return 3u; + + default: + throw ShaderLang::CompilerInvalidSwizzleError{ sourceLocation, std::string(&c, 1) }; + } } }