// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) // This file is part of the "Nazara Engine - Shader module" // For conditions of distribution and use, see copyright notice in Config.hpp #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace Nz::ShaderAst { struct SanitizeVisitor::CurrentFunctionData { std::optional stageType; Bitset<> calledFunctions; DeclareFunctionStatement* statement; FunctionFlags flags; }; template struct SanitizeVisitor::IdentifierList { Bitset availableIndices; Bitset preregisteredIndices; std::unordered_map values; void PreregisterIndex(std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { if (index < availableIndices.GetSize()) { if (!availableIndices.Test(index) && !preregisteredIndices.UnboundedTest(index)) throw ShaderLang::AstAlreadyUsedIndexPreregisterError{ sourceLocation, index }; } else if (index >= availableIndices.GetSize()) availableIndices.Resize(index + 1, true); availableIndices.Set(index, false); preregisteredIndices.UnboundedSet(index); } template std::size_t Register(U&& data, std::optional index, const ShaderLang::SourceLocation& sourceLocation) { std::size_t dataIndex; if (index.has_value()) { dataIndex = *index; if (dataIndex >= availableIndices.GetSize()) availableIndices.Resize(dataIndex + 1, true); else if (!availableIndices.Test(dataIndex)) { if (preregisteredIndices.UnboundedTest(dataIndex)) preregisteredIndices.Reset(dataIndex); else throw ShaderLang::AstInvalidIndexError{ sourceLocation, dataIndex }; } } else dataIndex = RegisterNewIndex(false); assert(values.find(dataIndex) == values.end()); availableIndices.Set(dataIndex, false); values.emplace(dataIndex, std::forward(data)); return dataIndex; } std::size_t RegisterNewIndex(bool preregister) { std::size_t index = availableIndices.FindFirst(); if (index == availableIndices.npos) { 
index = availableIndices.GetSize(); availableIndices.Resize(index + 1, true); } availableIndices.Set(index, false); if (preregister) preregisteredIndices.UnboundedSet(index); return index; } T& Retrieve(std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { auto it = values.find(index); if (it == values.end()) throw ShaderLang::AstInvalidIndexError{ sourceLocation, index }; return it->second; } T* TryRetrieve(std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { auto it = values.find(index); if (it == values.end()) { if (!preregisteredIndices.UnboundedTest(index)) throw ShaderLang::AstInvalidIndexError{ sourceLocation, index }; return nullptr; } return &it->second; } }; struct SanitizeVisitor::Scope { std::size_t previousSize; }; struct SanitizeVisitor::Environment { Uuid moduleId; std::shared_ptr parentEnv; std::vector identifiersInScope; std::vector scopes; }; struct SanitizeVisitor::NamedPartialType { std::string name; PartialType type; }; struct SanitizeVisitor::Context { struct ModuleData { std::unordered_map exportedSetByModule; std::shared_ptr environment; std::unique_ptr dependenciesVisitor; }; struct PendingFunction { DeclareFunctionStatement* cloneNode; const DeclareFunctionStatement* node; }; struct UsedExternalData { bool isConditional; }; static constexpr std::size_t ModuleIdSentinel = std::numeric_limits::max(); std::array entryFunctions = {}; std::vector modules; std::vector pendingFunctions; std::vector* currentStatementList = nullptr; std::unordered_map moduleByUuid; std::unordered_map usedBindingIndexes; std::unordered_map declaredExternalVar; std::shared_ptr globalEnv; std::shared_ptr currentEnv; std::shared_ptr moduleEnv; IdentifierList constantValues; IdentifierList functions; IdentifierList aliases; IdentifierList intrinsics; IdentifierList moduleIndices; IdentifierList structs; IdentifierList> types; IdentifierList variableTypes; ModulePtr currentModule; Options options; CurrentFunctionData* currentFunction = 
nullptr; bool allowUnknownIdentifiers = false; bool inConditionalStatement = false; }; ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error) { ModulePtr clone = std::make_shared(module.metadata, module.importedModules); Context currentContext; currentContext.options = options; currentContext.currentModule = clone; m_context = ¤tContext; CallOnExit resetContext([&] { m_context = nullptr; }); PreregisterIndices(module); // Register global env m_context->globalEnv = std::make_shared(); m_context->currentEnv = m_context->globalEnv; RegisterBuiltin(); m_context->moduleEnv = std::make_shared(); m_context->moduleEnv->moduleId = clone->metadata->moduleId; m_context->moduleEnv->parentEnv = m_context->globalEnv; for (std::size_t moduleId = 0; moduleId < clone->importedModules.size(); ++moduleId) { auto importedModuleEnv = std::make_shared(); importedModuleEnv->moduleId = clone->importedModules[moduleId].module->metadata->moduleId; importedModuleEnv->parentEnv = m_context->globalEnv; m_context->currentEnv = importedModuleEnv; auto& importedModule = clone->importedModules[moduleId]; importedModule.module->rootNode = SanitizeInternal(*importedModule.module->rootNode, error); if (!importedModule.module->rootNode) return {}; m_context->moduleByUuid[importedModule.module->metadata->moduleId] = moduleId; auto& moduleData = m_context->modules.emplace_back(); moduleData.environment = std::move(importedModuleEnv); m_context->currentEnv = m_context->globalEnv; RegisterModule(importedModule.identifier, moduleId); } m_context->currentEnv = m_context->moduleEnv; clone->rootNode = SanitizeInternal(*module.rootNode, error); if (!clone->rootNode) return {}; // Remove unused statements of imported modules for (std::size_t moduleId = 0; moduleId < clone->importedModules.size(); ++moduleId) { auto& moduleData = m_context->modules[moduleId]; auto& importedModule = clone->importedModules[moduleId]; if (moduleData.dependenciesVisitor) { 
moduleData.dependenciesVisitor->Resolve(true); //< allow unknown identifiers since we may be referencing other modules importedModule.module = EliminateUnusedPass(*importedModule.module, moduleData.dependenciesVisitor->GetUsage()); } } return clone; } ExpressionValue SanitizeVisitor::CloneType(const ExpressionValue& exprType) { if (!exprType.HasValue()) return {}; std::optional resolvedType = ResolveTypeExpr(exprType, false, {}); if (!resolvedType.has_value()) return AstCloner::CloneType(exprType); return std::move(resolvedType).value(); } ExpressionPtr SanitizeVisitor::Clone(AccessIdentifierExpression& node) { if (node.identifiers.empty()) throw ShaderLang::AstNoIdentifierError{ node.sourceLocation }; MandatoryExpr(node.expr, node.sourceLocation); // Handle module access (TODO: Add namespace expression?) if (node.expr->GetType() == NodeType::IdentifierExpression && node.identifiers.size() == 1) { auto& identifierExpr = static_cast(*node.expr); const IdentifierData* identifierData = FindIdentifier(identifierExpr.identifier); if (identifierData && identifierData->category == IdentifierCategory::Module) { std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index, node.sourceLocation); const auto& env = *m_context->modules[moduleIndex].environment; identifierData = FindIdentifier(env, node.identifiers.front().identifier); if (identifierData) return HandleIdentifier(identifierData, node.identifiers.front().sourceLocation); } } ExpressionPtr indexedExpr = CloneExpression(node.expr); for (const auto& identifierEntry : node.identifiers) { if (identifierEntry.identifier.empty()) throw ShaderLang::AstEmptyIdentifierError{ identifierEntry.sourceLocation }; const ExpressionType* exprType = GetExpressionType(*indexedExpr); if (!exprType) return AstCloner::Clone(node); //< unresolved type const ExpressionType& resolvedType = ResolveAlias(*exprType); // TODO: Add proper support for methods if (IsSamplerType(resolvedType)) { if (identifierEntry.identifier 
== "Sample") { // TODO: Add a MethodExpression? auto identifierExpr = std::make_unique(); identifierExpr->expr = std::move(indexedExpr); identifierExpr->identifiers.emplace_back().identifier = identifierEntry.identifier; MethodType methodType; methodType.methodIndex = 0; //< FIXME methodType.objectType = std::make_unique(); methodType.objectType->type = resolvedType; identifierExpr->cachedExpressionType = std::move(methodType); indexedExpr = std::move(identifierExpr); } else throw ShaderLang::CompilerUnknownMethodError{ identifierEntry.sourceLocation }; } else if (IsStructType(resolvedType)) { std::size_t structIndex = ResolveStruct(resolvedType, indexedExpr->sourceLocation); const StructDescription* s = m_context->structs.Retrieve(structIndex, indexedExpr->sourceLocation); // Retrieve member index (not counting disabled fields) Int32 fieldIndex = 0; const StructDescription::StructMember* fieldPtr = nullptr; for (const auto& field : s->members) { if (field.cond.HasValue()) { if (!field.cond.IsResultingValue()) { if (m_context->options.allowPartialSanitization) return AstCloner::Clone(node); //< unresolved throw ShaderLang::CompilerConstantExpressionRequiredError{ field.cond.GetExpression()->sourceLocation }; } else if (!field.cond.GetResultingValue()) continue; } if (field.name == identifierEntry.identifier) { fieldPtr = &field; break; } fieldIndex++; } if (!fieldPtr) { if (s->isConditional) return AstCloner::Clone(node); //< unresolved throw ShaderLang::CompilerUnknownFieldError{ indexedExpr->sourceLocation, identifierEntry.identifier }; } if (m_context->options.useIdentifierAccessesForStructs) { // Use a AccessIdentifierExpression AccessIdentifierExpression* accessIdentifierPtr; if (indexedExpr->GetType() != NodeType::AccessIdentifierExpression) { std::unique_ptr accessIndex = std::make_unique(); accessIndex->sourceLocation = indexedExpr->sourceLocation; accessIndex->expr = std::move(indexedExpr); accessIdentifierPtr = accessIndex.get(); indexedExpr = 
std::move(accessIndex); } else { accessIdentifierPtr = static_cast(indexedExpr.get()); accessIdentifierPtr->sourceLocation.ExtendToRight(indexedExpr->sourceLocation); } accessIdentifierPtr->cachedExpressionType = ResolveTypeExpr(fieldPtr->type, false, identifierEntry.sourceLocation); auto& newIdentifierEntry = accessIdentifierPtr->identifiers.emplace_back(); newIdentifierEntry.identifier = fieldPtr->name; newIdentifierEntry.sourceLocation = indexedExpr->sourceLocation; } else { // Transform to AccessIndexExpression std::unique_ptr accessIndex = std::make_unique(); accessIndex->sourceLocation = indexedExpr->sourceLocation; accessIndex->expr = std::move(indexedExpr); accessIndex->indices.push_back(ShaderBuilder::Constant(fieldIndex)); accessIndex->cachedExpressionType = ResolveTypeExpr(fieldPtr->type, false, identifierEntry.sourceLocation); indexedExpr = std::move(accessIndex); } } else if (IsPrimitiveType(resolvedType) || IsVectorType(resolvedType)) { // Swizzle expression std::size_t swizzleComponentCount = identifierEntry.identifier.size(); if (swizzleComponentCount > 4) throw ShaderLang::CompilerInvalidSwizzleError{ identifierEntry.sourceLocation }; if (m_context->options.removeScalarSwizzling && IsPrimitiveType(resolvedType)) { for (std::size_t j = 0; j < swizzleComponentCount; ++j) { if (ToSwizzleIndex(identifierEntry.identifier[j], identifierEntry.sourceLocation) != 0) throw ShaderLang::CompilerInvalidScalarSwizzleError{ identifierEntry.sourceLocation }; } if (swizzleComponentCount == 1) continue; //< ignore this swizzle (a.x == a) // Use a Cast expression to replace swizzle indexedExpr = CacheResult(std::move(indexedExpr)); //< Since we are going to use a value multiple times, cache it if required PrimitiveType baseType; if (IsVectorType(resolvedType)) baseType = std::get(resolvedType).type; else baseType = std::get(resolvedType); auto cast = std::make_unique(); cast->targetType = ExpressionType{ VectorType{ swizzleComponentCount, baseType } }; for 
(std::size_t j = 0; j < swizzleComponentCount; ++j) cast->expressions[j] = CloneExpression(indexedExpr); Validate(*cast); indexedExpr = std::move(cast); } else { auto swizzle = std::make_unique(); swizzle->expression = std::move(indexedExpr); swizzle->componentCount = swizzleComponentCount; for (std::size_t j = 0; j < swizzleComponentCount; ++j) swizzle->components[j] = ToSwizzleIndex(identifierEntry.identifier[j], identifierEntry.sourceLocation); Validate(*swizzle); indexedExpr = std::move(swizzle); } } else throw ShaderLang::CompilerUnexpectedAccessedTypeError{ node.sourceLocation }; } return indexedExpr; } ExpressionPtr SanitizeVisitor::Clone(AccessIndexExpression& node) { MandatoryExpr(node.expr, node.sourceLocation); for (auto& index : node.indices) MandatoryExpr(index, node.sourceLocation); auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); Validate(*clone); // TODO: Handle AccessIndex on structs with m_context->options.useIdentifierAccessesForStructs return clone; } ExpressionPtr SanitizeVisitor::Clone(AliasValueExpression& node) { const Identifier* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(node.aliasId, node.sourceLocation), node.sourceLocation); ExpressionPtr targetExpr = HandleIdentifier(&targetIdentifier->target, node.sourceLocation); if (m_context->options.removeAliases) return targetExpr; AliasType aliasType; aliasType.aliasIndex = node.aliasId; aliasType.targetType = std::make_unique(); aliasType.targetType->type = *targetExpr->cachedExpressionType; auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); clone->cachedExpressionType = std::move(aliasType); return clone; } ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node) { MandatoryExpr(node.left, node.sourceLocation); MandatoryExpr(node.right, node.sourceLocation); auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); Validate(*clone); return clone; } ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node) { auto clone = 
StaticUniquePointerCast(AstCloner::Clone(node)); Validate(*clone); return clone; } ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node) { ExpressionPtr targetExpr = CloneExpression(MandatoryExpr(node.targetFunction, node.sourceLocation)); const ExpressionType* targetExprType = GetExpressionType(*targetExpr); if (!targetExprType) return AstCloner::Clone(node); //< unresolved type const ExpressionType& resolvedType = ResolveAlias(*targetExprType); if (IsFunctionType(resolvedType)) { if (!m_context->currentFunction) throw ShaderLang::CompilerFunctionCallOutsideOfFunctionError{ node.sourceLocation }; std::size_t targetFuncIndex; if (targetExpr->GetType() == NodeType::FunctionExpression) targetFuncIndex = static_cast(*targetExpr).funcId; else if (targetExpr->GetType() == NodeType::AliasValueExpression) { const auto& alias = static_cast(*targetExpr); const Identifier* aliasIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId, node.sourceLocation), targetExpr->sourceLocation); if (aliasIdentifier->target.category != IdentifierCategory::Function) throw ShaderLang::CompilerExpectedFunctionError{ targetExpr->sourceLocation }; targetFuncIndex = aliasIdentifier->target.index; } else throw ShaderLang::CompilerExpectedFunctionError{ targetExpr->sourceLocation }; auto clone = std::make_unique(); clone->sourceLocation = node.sourceLocation; clone->targetFunction = std::move(targetExpr); clone->parameters.reserve(node.parameters.size()); for (const auto& parameter : node.parameters) clone->parameters.push_back(CloneExpression(parameter)); m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex); Validate(*clone); return clone; } else if (IsIntrinsicFunctionType(resolvedType)) { if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression) throw ShaderLang::CompilerExpectedIntrinsicFunctionError{ targetExpr->sourceLocation }; std::size_t targetIntrinsicId = static_cast(*targetExpr).intrinsicId; std::vector parameters; 
parameters.reserve(node.parameters.size()); for (const auto& param : node.parameters) parameters.push_back(CloneExpression(param)); auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics.Retrieve(targetIntrinsicId, node.sourceLocation), std::move(parameters)); intrinsic->sourceLocation = node.sourceLocation; Validate(*intrinsic); return intrinsic; } else if (IsMethodType(resolvedType)) { const MethodType& methodType = std::get(resolvedType); std::vector parameters; parameters.reserve(node.parameters.size() + 1); // TODO: Add MethodExpression assert(targetExpr->GetType() == NodeType::AccessIdentifierExpression); parameters.push_back(std::move(static_cast(*targetExpr).expr)); for (const auto& param : node.parameters) parameters.push_back(CloneExpression(param)); assert(IsSamplerType(methodType.objectType->type) && methodType.methodIndex == 0); auto intrinsic = ShaderBuilder::Intrinsic(IntrinsicType::SampleTexture, std::move(parameters)); intrinsic->sourceLocation = node.sourceLocation; Validate(*intrinsic); return intrinsic; } else { // Calling a type - vec3[f32](0.0, 1.0, 2.0) - it's a cast auto clone = std::make_unique(); clone->sourceLocation = node.sourceLocation; clone->targetType = *targetExprType; if (node.parameters.size() > clone->expressions.size()) throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation }; for (std::size_t i = 0; i < node.parameters.size(); ++i) clone->expressions[i] = CloneExpression(node.parameters[i]); Validate(*clone); return Clone(*clone); //< Necessary because cast has to be modified (FIXME) } } ExpressionPtr SanitizeVisitor::Clone(CastExpression& node) { auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); if (Validate(*clone) == ValidationResult::Unresolved) return clone; //< unresolved const ExpressionType& targetType = clone->targetType.GetResultingValue(); if (m_context->options.removeMatrixCast && IsMatrixType(targetType)) { const MatrixType& targetMatrixType = std::get(targetType); const 
ExpressionType& frontExprType = ResolveAlias(GetExpressionTypeSecure(*clone->expressions.front())); bool isMatrixCast = IsMatrixType(frontExprType); if (isMatrixCast && std::get(frontExprType) == targetMatrixType) { // Nothing to do return std::move(clone->expressions.front()); } auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", targetType); //< Validation will prevent name-clash if required variableDeclaration->sourceLocation = node.sourceLocation; Validate(*variableDeclaration); std::size_t variableIndex = *variableDeclaration->varIndex; m_context->currentStatementList->emplace_back(std::move(variableDeclaration)); for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i) { // temp[i] auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, targetType), ShaderBuilder::Constant(UInt32(i))); columnExpr->sourceLocation = node.sourceLocation; Validate(*columnExpr); // vector expression ExpressionPtr vectorExpr; std::size_t vectorComponentCount; if (isMatrixCast) { // fromMatrix[i] auto matrixColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(clone->expressions.front()), ShaderBuilder::Constant(UInt32(i))); matrixColumnExpr->sourceLocation = node.sourceLocation; Validate(*matrixColumnExpr); vectorExpr = std::move(matrixColumnExpr); vectorComponentCount = std::get(frontExprType).rowCount; } else { // parameter #i vectorExpr = std::move(clone->expressions[i]); vectorComponentCount = std::get(ResolveAlias(GetExpressionTypeSecure(*vectorExpr))).componentCount; } // cast expression (turn fromMatrix[i] to vec3[f32](fromMatrix[i])) ExpressionPtr castExpr; if (vectorComponentCount != targetMatrixType.rowCount) { CastExpressionPtr vecCast; if (vectorComponentCount < targetMatrixType.rowCount) { std::array expressions; expressions[0] = std::move(vectorExpr); for (std::size_t j = 0; j < targetMatrixType.rowCount - vectorComponentCount; ++j) expressions[j + 1] = ShaderBuilder::Constant(ExpressionType{ targetMatrixType.type }, 
(i == j + vectorComponentCount) ? 1 : 0); //< set 1 to diagonal vecCast = ShaderBuilder::Cast(ExpressionType{ VectorType{ targetMatrixType.rowCount, targetMatrixType.type } }, std::move(expressions)); vecCast->sourceLocation = node.sourceLocation; Validate(*vecCast); castExpr = std::move(vecCast); } else { std::array swizzleComponents; std::iota(swizzleComponents.begin(), swizzleComponents.begin() + targetMatrixType.rowCount, 0); auto swizzleExpr = ShaderBuilder::Swizzle(std::move(vectorExpr), swizzleComponents, targetMatrixType.rowCount); swizzleExpr->sourceLocation = node.sourceLocation; Validate(*swizzleExpr); castExpr = std::move(swizzleExpr); } } else castExpr = std::move(vectorExpr); // temp[i] = castExpr auto assignExpr = ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr)); assignExpr->sourceLocation = node.sourceLocation; m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(std::move(assignExpr))); } auto varExpr = ShaderBuilder::Variable(variableIndex, targetType); varExpr->sourceLocation = node.sourceLocation; return varExpr; } return clone; } ExpressionPtr SanitizeVisitor::Clone(ConditionalExpression& node) { MandatoryExpr(node.condition, node.sourceLocation); MandatoryExpr(node.truePath, node.sourceLocation); MandatoryExpr(node.falsePath, node.sourceLocation); ExpressionPtr cloneCondition = AstCloner::Clone(*node.condition); std::optional conditionValue = ComputeConstantValue(*cloneCondition); if (!conditionValue.has_value()) { // Unresolvable condition return AstCloner::Clone(node); } if (GetConstantType(*conditionValue) != ExpressionType{ PrimitiveType::Boolean }) throw ShaderLang::CompilerConditionExpectedBoolError{ cloneCondition->sourceLocation, ToString(GetConstantType(*conditionValue), cloneCondition->sourceLocation) }; if (std::get(*conditionValue)) return AstCloner::Clone(*node.truePath); else return AstCloner::Clone(*node.falsePath); } ExpressionPtr 
SanitizeVisitor::Clone(ConstantValueExpression& node) { if (std::holds_alternative(node.value)) throw ShaderLang::CompilerConstantExpectedValueError{ node.sourceLocation }; auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); clone->cachedExpressionType = GetConstantType(clone->value); return clone; } ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) { const ConstantValue* value = m_context->constantValues.TryRetrieve(node.constantId, node.sourceLocation); if (!value) { if (!m_context->options.allowPartialSanitization) throw ShaderLang::AstInvalidConstantIndexError{ node.sourceLocation, node.constantId }; return AstCloner::Clone(node); //< unresolved } // Replace by constant value auto constant = ShaderBuilder::Constant(*value); constant->cachedExpressionType = GetConstantType(constant->value); constant->sourceLocation = node.sourceLocation; return constant; } ExpressionPtr SanitizeVisitor::Clone(IdentifierExpression& node) { assert(m_context); const IdentifierData* identifierData = FindIdentifier(node.identifier); if (!identifierData) { if (m_context->allowUnknownIdentifiers) return AstCloner::Clone(node); throw ShaderLang::CompilerUnknownIdentifierError{ node.sourceLocation, node.identifier }; } if (identifierData->category == IdentifierCategory::Unresolved) return AstCloner::Clone(node); return HandleIdentifier(identifierData, node.sourceLocation); } ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node) { auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); Validate(*clone); return clone; } ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node) { auto expression = CloneExpression(MandatoryExpr(node.expression, node.sourceLocation)); const ExpressionType* exprType = GetExpressionType(*expression); if (!exprType) { auto swizzleExpr = ShaderBuilder::Swizzle(std::move(expression), node.components, node.componentCount); //< unresolved swizzleExpr->cachedExpressionType = node.cachedExpressionType; 
swizzleExpr->sourceLocation = node.sourceLocation; return swizzleExpr; } const ExpressionType& resolvedExprType = ResolveAlias(*exprType); if (m_context->options.removeScalarSwizzling && IsPrimitiveType(resolvedExprType)) { for (std::size_t i = 0; i < node.componentCount; ++i) { if (node.components[i] != 0) throw ShaderLang::CompilerInvalidScalarSwizzleError{ node.sourceLocation }; } if (node.componentCount == 1) return expression; //< ignore this swizzle (a.x == a) // Use a Cast expression to replace swizzle expression = CacheResult(std::move(expression)); //< Since we are going to use a value multiple times, cache it if required PrimitiveType baseType; if (IsVectorType(resolvedExprType)) baseType = std::get(resolvedExprType).type; else baseType = std::get(resolvedExprType); auto cast = std::make_unique(); cast->sourceLocation = node.sourceLocation; cast->targetType = ExpressionType{ VectorType{ node.componentCount, baseType } }; for (std::size_t j = 0; j < node.componentCount; ++j) cast->expressions[j] = CloneExpression(expression); Validate(*cast); return cast; } else { auto clone = std::make_unique(); clone->componentCount = node.componentCount; clone->components = node.components; clone->expression = std::move(expression); clone->sourceLocation = node.sourceLocation; Validate(*clone); return clone; } } ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node) { auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); Validate(*clone); return clone; } ExpressionPtr SanitizeVisitor::Clone(VariableValueExpression& node) { auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); Validate(*clone); return clone; } StatementPtr SanitizeVisitor::Clone(BranchStatement& node) { if (node.isConst) { // Evaluate every condition at compilation and select the right statement for (auto& cond : node.condStatements) { MandatoryExpr(cond.condition, node.sourceLocation); std::optional conditionValue = ComputeConstantValue(*AstCloner::Clone(*cond.condition)); if 
(!conditionValue.has_value()) return AstCloner::Clone(node); //< Unresolvable condition if (GetConstantType(*conditionValue) != ExpressionType{ PrimitiveType::Boolean }) throw ShaderLang::CompilerConditionExpectedBoolError{ cond.condition->sourceLocation, ToString(GetConstantType(*conditionValue), cond.condition->sourceLocation) }; if (std::get(*conditionValue)) return Unscope(AstCloner::Clone(*cond.statement)); } // Every condition failed, fallback to else if any if (node.elseStatement) return Unscope(AstCloner::Clone(*node.elseStatement)); else return ShaderBuilder::NoOp(); } auto clone = std::make_unique(); clone->condStatements.reserve(node.condStatements.size()); if (!m_context->currentFunction) throw ShaderLang::CompilerBranchOutsideOfFunctionError{ node.sourceLocation }; BranchStatement* root = clone.get(); for (std::size_t condIndex = 0; condIndex < node.condStatements.size(); ++condIndex) { auto& cond = node.condStatements[condIndex]; PushScope(); auto BuildCondStatement = [&](BranchStatement::ConditionalStatement& condStatement) { condStatement.condition = CloneExpression(MandatoryExpr(cond.condition, node.sourceLocation)); const ExpressionType* condType = GetExpressionType(*condStatement.condition); if (!condType) return ValidationResult::Unresolved; if (!IsPrimitiveType(*condType) || std::get(*condType) != PrimitiveType::Boolean) throw ShaderLang::CompilerConditionExpectedBoolError{ condStatement.condition->sourceLocation, ToString(*condType, condStatement.condition->sourceLocation)}; condStatement.statement = CloneStatement(MandatoryStatement(cond.statement, node.sourceLocation)); return ValidationResult::Validated; }; if (m_context->options.splitMultipleBranches && condIndex > 0) { auto currentBranch = std::make_unique(); if (BuildCondStatement(currentBranch->condStatements.emplace_back()) == ValidationResult::Unresolved) return AstCloner::Clone(node); root->elseStatement = std::move(currentBranch); root = static_cast(root->elseStatement.get()); } 
else { if (BuildCondStatement(clone->condStatements.emplace_back()) == ValidationResult::Unresolved) return AstCloner::Clone(node); } PopScope(); } if (node.elseStatement) { PushScope(); root->elseStatement = CloneStatement(node.elseStatement); PopScope(); } return clone; } StatementPtr SanitizeVisitor::Clone(ConditionalStatement& node) { MandatoryExpr(node.condition, node.sourceLocation); MandatoryStatement(node.statement, node.sourceLocation); ExpressionPtr cloneCondition = AstCloner::Clone(*node.condition); std::optional conditionValue = ComputeConstantValue(*cloneCondition); bool wasInConditionalStatement = m_context->inConditionalStatement; m_context->inConditionalStatement = true; CallOnExit restoreCond([=] { m_context->inConditionalStatement = wasInConditionalStatement; }); if (!conditionValue.has_value()) { // Unresolvable condition auto condStatement = ShaderBuilder::ConditionalStatement(std::move(cloneCondition), AstCloner::Clone(*node.statement)); condStatement->sourceLocation = node.sourceLocation; return condStatement; } if (GetConstantType(*conditionValue) != ExpressionType{ PrimitiveType::Boolean }) throw ShaderLang::CompilerConditionExpectedBoolError{ cloneCondition->sourceLocation, ToString(GetConstantType(*conditionValue), cloneCondition->sourceLocation) }; if (std::get(*conditionValue)) return AstCloner::Clone(*node.statement); else return ShaderBuilder::NoOp(); } StatementPtr SanitizeVisitor::Clone(DeclareAliasStatement& node) { auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); Validate(*clone); if (m_context->options.removeAliases) return ShaderBuilder::NoOp(); return clone; } StatementPtr SanitizeVisitor::Clone(DeclareConstStatement& node) { auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); if (!clone->expression) throw ShaderLang::CompilerConstMissingExpressionError{ node.sourceLocation }; clone->expression = PropagateConstants(*clone->expression); if (clone->expression->GetType() != NodeType::ConstantValueExpression) 
{ if (!m_context->options.allowPartialSanitization) throw ShaderLang::CompilerConstantExpressionRequiredError{ clone->expression->sourceLocation }; clone->constIndex = RegisterConstant(clone->name, std::nullopt, clone->constIndex, node.sourceLocation); return clone; } const ConstantValue& value = static_cast(*clone->expression).value; ExpressionType expressionType = GetConstantType(value); std::optional constType = ResolveTypeExpr(clone->type, true, node.sourceLocation); if (clone->type.HasValue() && constType.has_value() && *constType != ResolveAlias(expressionType)) throw ShaderLang::CompilerVarDeclarationTypeUnmatchingError{ clone->expression->sourceLocation, ToString(expressionType, clone->expression->sourceLocation), ToString(*constType, node.sourceLocation) }; clone->type = expressionType; clone->constIndex = RegisterConstant(clone->name, value, clone->constIndex, node.sourceLocation); if (m_context->options.removeConstDeclaration) return ShaderBuilder::NoOp(); return clone; } StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node) { assert(m_context); auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); std::optional defaultBlockSet = 0; if (clone->bindingSet.HasValue()) { if (ComputeExprValue(clone->bindingSet, node.sourceLocation) == ValidationResult::Validated) defaultBlockSet = clone->bindingSet.GetResultingValue(); else defaultBlockSet.reset(); //< Unresolved value } for (auto& extVar : clone->externalVars) { if (!extVar.bindingIndex.HasValue()) throw ShaderLang::CompilerExtMissingBindingIndexError{ extVar.sourceLocation }; if (extVar.bindingSet.HasValue()) ComputeExprValue(extVar.bindingSet, node.sourceLocation); else if (defaultBlockSet) extVar.bindingSet = *defaultBlockSet; ComputeExprValue(extVar.bindingIndex, node.sourceLocation); Context::UsedExternalData usedBindingData; usedBindingData.isConditional = m_context->inConditionalStatement; if (extVar.bindingSet.IsResultingValue() && extVar.bindingIndex.IsResultingValue()) { 
UInt64 bindingSet = extVar.bindingSet.GetResultingValue(); UInt64 bindingIndex = extVar.bindingIndex.GetResultingValue(); UInt64 bindingKey = bindingSet << 32 | bindingIndex; if (auto it = m_context->usedBindingIndexes.find(bindingKey); it != m_context->usedBindingIndexes.end()) { if (!it->second.isConditional || !usedBindingData.isConditional) throw ShaderLang::CompilerExtBindingAlreadyUsedError{ extVar.sourceLocation, UInt32(bindingSet), UInt32(bindingIndex) }; } m_context->usedBindingIndexes.emplace(bindingKey, usedBindingData); } if (auto it = m_context->declaredExternalVar.find(extVar.name); it != m_context->declaredExternalVar.end()) { if (!it->second.isConditional || !usedBindingData.isConditional) throw ShaderLang::CompilerExtAlreadyDeclaredError{ extVar.sourceLocation, extVar.name }; } m_context->declaredExternalVar.emplace(extVar.name, usedBindingData); std::optional resolvedType = ResolveTypeExpr(extVar.type, false, node.sourceLocation); if (!resolvedType.has_value()) { RegisterUnresolved(extVar.name); continue; } const ExpressionType& targetType = ResolveAlias(*resolvedType); ExpressionType varType; if (IsUniformType(targetType)) varType = std::get(targetType).containedType; else if (IsSamplerType(targetType)) varType = targetType; else throw ShaderLang::CompilerExtTypeNotAllowedError{ extVar.sourceLocation, extVar.name, ToString(*resolvedType, extVar.sourceLocation) }; extVar.type = std::move(resolvedType).value(); extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), extVar.varIndex, extVar.sourceLocation); SanitizeIdentifier(extVar.name); } return clone; } StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node) { if (m_context->currentFunction) throw ShaderLang::CompilerFunctionDeclarationInsideFunctionError{ node.sourceLocation }; auto clone = std::make_unique(); clone->name = node.name; clone->parameters.reserve(node.parameters.size()); for (auto& parameter : node.parameters) { auto& cloneParam = 
clone->parameters.emplace_back(); cloneParam.name = parameter.name; cloneParam.type = CloneType(parameter.type); cloneParam.varIndex = parameter.varIndex; cloneParam.sourceLocation = parameter.sourceLocation; } if (node.returnType.HasValue()) clone->returnType = CloneType(node.returnType); else clone->returnType = ExpressionType{ NoType{} }; if (node.depthWrite.HasValue()) ComputeExprValue(node.depthWrite, clone->depthWrite, node.sourceLocation); if (node.earlyFragmentTests.HasValue()) ComputeExprValue(node.earlyFragmentTests, clone->earlyFragmentTests, node.sourceLocation); if (node.entryStage.HasValue()) ComputeExprValue(node.entryStage, clone->entryStage, node.sourceLocation); if (node.isExported.HasValue()) ComputeExprValue(node.isExported, clone->isExported, node.sourceLocation); if (clone->entryStage.IsResultingValue()) { ShaderStageType stageType = clone->entryStage.GetResultingValue(); if (!m_context->options.allowPartialSanitization) { if (m_context->entryFunctions[UnderlyingCast(stageType)]) throw ShaderLang::CompilerEntryPointAlreadyDefinedError{ clone->sourceLocation, stageType }; m_context->entryFunctions[UnderlyingCast(stageType)] = &node; } if (node.parameters.size() > 1) throw ShaderLang::CompilerEntryFunctionParameterError{ node.parameters[1].sourceLocation }; if (!node.parameters.empty()) { auto& parameter = node.parameters.front(); if (parameter.type.IsResultingValue()) { if (!IsStructType(ResolveAlias(parameter.type.GetResultingValue()))) throw ShaderLang::CompilerEntryFunctionParameterError{ parameter.sourceLocation }; } } if (stageType != ShaderStageType::Fragment) { if (node.depthWrite.HasValue()) throw ShaderLang::CompilerDepthWriteAttributeError{ node.sourceLocation }; if (node.earlyFragmentTests.HasValue()) throw ShaderLang::CompilerEarlyFragmentTestsAttributeError{ node.sourceLocation }; } } // Function content is resolved in a second pass auto& pendingFunc = m_context->pendingFunctions.emplace_back(); pendingFunc.cloneNode = clone.get(); 
pendingFunc.node = &node; if (clone->earlyFragmentTests.HasValue() && clone->earlyFragmentTests.GetResultingValue()) { //TODO: warning and disable early fragment tests throw ShaderLang::CompilerDiscardEarlyFragmentTestsError{ node.sourceLocation }; } FunctionData funcData; funcData.node = clone.get(); //< update function node std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData), node.funcIndex, node.sourceLocation); clone->funcIndex = funcIndex; SanitizeIdentifier(clone->name); return clone; } StatementPtr SanitizeVisitor::Clone(DeclareOptionStatement& node) { if (m_context->currentFunction) throw ShaderLang::CompilerOptionDeclarationInsideFunctionError{ node.sourceLocation }; auto clone = StaticUniquePointerCast(AstCloner::Clone(node)); if (clone->optName.empty()) throw ShaderLang::AstEmptyIdentifierError{ node.sourceLocation }; std::optional resolvedOptionType = ResolveTypeExpr(clone->optType, false, node.sourceLocation); if (!resolvedOptionType) { clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex, node.sourceLocation); return clone; } ExpressionType resolvedType = ResolveType(*resolvedOptionType, false, node.sourceLocation); const ExpressionType& targetType = ResolveAlias(resolvedType); if (clone->defaultValue) { const ExpressionType* defaultValueType = GetExpressionType(*clone->defaultValue); if (!defaultValueType) { clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex, node.sourceLocation); return clone; //< unresolved } if (targetType != *defaultValueType) throw ShaderLang::CompilerVarDeclarationTypeUnmatchingError{ node.sourceLocation }; } clone->optType = std::move(resolvedType); UInt32 optionHash = CRC32(reinterpret_cast(clone->optName.data()), clone->optName.size()); if (auto optionValueIt = m_context->options.optionValues.find(optionHash); optionValueIt != m_context->options.optionValues.end()) clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second, 
node.optIndex, node.sourceLocation);
		else
		{
			// No value was given for this option in the sanitization options
			if (m_context->options.allowPartialSanitization)
			{
				// Partial sanitization, we cannot give a value to this option
				clone->optIndex = RegisterConstant(clone->optName, std::nullopt, clone->optIndex, node.sourceLocation);
			}
			else
			{
				// Full sanitization: fall back on the default value, which is then mandatory
				if (!clone->defaultValue)
					throw ShaderLang::CompilerMissingOptionValueError{ node.sourceLocation, clone->optName };

				clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue), node.optIndex, node.sourceLocation);
			}
		}

		// The option stays registered even when its declaration is dropped from the output
		if (m_context->options.removeOptionDeclaration)
			return ShaderBuilder::NoOp();

		return clone;
	}

	// Sanitizes a struct declaration: evaluates struct and member attributes, validates members
	// (duplicates, builtin/location exclusivity, std140 layout restrictions) and registers the struct.
	// Throws CompilerStructDeclarationInsideFunctionError, CompilerStructFieldBuiltinLocationError,
	// CompilerStructFieldMultipleError, CompilerStructLayoutTypeNotAllowedError or
	// CompilerStructLayoutInnerMismatchError.
	StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node)
	{
		if (m_context->currentFunction)
			throw ShaderLang::CompilerStructDeclarationInsideFunctionError{ node.sourceLocation };

		auto clone = StaticUniquePointerCast(AstCloner::Clone(node));

		if (clone->isExported.HasValue())
			ComputeExprValue(clone->isExported, node.sourceLocation);

		if (clone->description.layout.HasValue())
			ComputeExprValue(clone->description.layout, node.sourceLocation);

		// Names of the members visited so far, used to detect duplicates
		std::unordered_set declaredMembers;
		for (auto& member : clone->description.members)
		{
			if (member.cond.HasValue())
			{
				ComputeExprValue(member.cond, member.sourceLocation);
				// Members whose condition is known to be false are not validated
				if (member.cond.IsResultingValue() && !member.cond.GetResultingValue())
					continue;
			}

			if (member.builtin.HasValue())
				ComputeExprValue(member.builtin, member.sourceLocation);

			if (member.locationIndex.HasValue())
				ComputeExprValue(member.locationIndex, member.sourceLocation);

			// A member cannot have both a builtin and a location attribute
			if (member.builtin.HasValue() && member.locationIndex.HasValue())
				throw ShaderLang::CompilerStructFieldBuiltinLocationError{ member.sourceLocation };

			if (declaredMembers.find(member.name) != declaredMembers.end())
			{
				// A duplicated name is an error unless this member has a resolved condition or we are partially sanitizing
				if ((!member.cond.HasValue() || !member.cond.IsResultingValue()) && !m_context->options.allowPartialSanitization)
					throw ShaderLang::CompilerStructFieldMultipleError{ member.sourceLocation, member.name };
			}

			declaredMembers.insert(member.name);

			if (member.type.HasValue() &&
member.type.IsExpression())
			{
				// Member type is still an unresolved expression, which is only expected when partially sanitizing
				assert(m_context->options.allowPartialSanitization);
				continue;
			}

			// NOTE(review): resolvedType is not used afterwards in the visible code
			ExpressionType resolvedType = member.type.GetResultingValue();
			if (clone->description.layout.IsResultingValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
			{
				const ExpressionType& targetType = ResolveAlias(member.type.GetResultingValue());

				// std140 structs cannot contain booleans, and inner structs must share the same layout
				if (IsPrimitiveType(targetType) && std::get(targetType) == PrimitiveType::Boolean)
					throw ShaderLang::CompilerStructLayoutTypeNotAllowedError{ member.sourceLocation, "bool", "std140" };
				else if (IsStructType(targetType))
				{
					std::size_t structIndex = std::get(targetType).structIndex;
					const StructDescription* desc = m_context->structs.Retrieve(structIndex, member.sourceLocation);
					if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
						throw ShaderLang::CompilerStructLayoutInnerMismatchError{ member.sourceLocation, "std140", "" };
				}
			}
		}

		clone->description.isConditional = m_context->inConditionalStatement;

		// The registered description points into the clone
		clone->structIndex = RegisterStruct(clone->description.name, &clone->description, clone->structIndex, clone->sourceLocation);

		SanitizeIdentifier(clone->description.name);

		return clone;
	}

	// Sanitizes a variable declaration, which is only allowed inside a function.
	// Throws CompilerVarDeclarationOutsideOfFunctionError otherwise.
	StatementPtr SanitizeVisitor::Clone(DeclareVariableStatement& node)
	{
		if (!m_context->currentFunction)
			throw ShaderLang::CompilerVarDeclarationOutsideOfFunctionError{ node.sourceLocation };

		auto clone = StaticUniquePointerCast(AstCloner::Clone(node));
		Validate(*clone);

		return clone;
	}

	// Sanitizes a discard statement and flags the current function as discarding.
	// Throws CompilerDiscardOutsideOfFunctionError when used outside of a function.
	StatementPtr SanitizeVisitor::Clone(DiscardStatement& node)
	{
		if (!m_context->currentFunction)
			throw ShaderLang::CompilerDiscardOutsideOfFunctionError{ node.sourceLocation };

		m_context->currentFunction->flags |= FunctionFlag::DoesDiscard;

		return AstCloner::Clone(node);
	}

	// Sanitizes an expression statement, whose expression is mandatory.
	StatementPtr SanitizeVisitor::Clone(ExpressionStatement& node)
	{
		MandatoryExpr(node.expression, node.sourceLocation);

		return AstCloner::Clone(node);
	}

	// Sanitizes a numerical for loop.
	// Depending on the unroll attribute and on the reduceLoopsToWhile option, the loop is either unrolled
	// into scoped statements, rewritten as a while loop, or cloned as a for loop.
	StatementPtr SanitizeVisitor::Clone(ForStatement& node)
	{
		if
(node.varName.empty())
			throw ShaderLang::AstEmptyIdentifierError{ node.sourceLocation };

		auto fromExpr = CloneExpression(MandatoryExpr(node.fromExpr, node.sourceLocation));
		auto stepExpr = CloneExpression(node.stepExpr); //< the step expression is not checked as mandatory
		auto toExpr = CloneExpression(MandatoryExpr(node.toExpr, node.sourceLocation));
		MandatoryStatement(node.statement, node.sourceLocation);

		const ExpressionType* fromExprType = GetExpressionType(*fromExpr);
		const ExpressionType* toExprType = GetExpressionType(*toExpr);

		ExpressionValue unrollValue;

		// Fallback: clone the loop as a for statement.
		// It moves from fromExpr/stepExpr/toExpr/unrollValue, so those must not be used after calling it.
		auto CloneFor = [&]
		{
			auto clone = std::make_unique();
			clone->fromExpr = std::move(fromExpr);
			clone->stepExpr = std::move(stepExpr);
			clone->toExpr = std::move(toExpr);
			clone->varName = node.varName;
			clone->unroll = std::move(unrollValue);

			// The counter variable only lives in the scope of the loop
			PushScope();
			{
				if (fromExprType)
					clone->varIndex = RegisterVariable(node.varName, *fromExprType, node.varIndex, node.sourceLocation);
				else
				{
					RegisterUnresolved(node.varName);
					clone->varIndex = node.varIndex; //< preserve var index, if set
				}
				clone->statement = CloneStatement(node.statement);
			}
			PopScope();

			SanitizeIdentifier(clone->varName);

			return clone;
		};

		if (node.unroll.HasValue() && ComputeExprValue(node.unroll, unrollValue, node.sourceLocation) == ValidationResult::Unresolved)
			return CloneFor(); //< unresolved unroll

		if (!fromExprType || !toExprType)
			return CloneFor(); //< unresolved from/to type

		// The counter has to be a 32-bit signed or unsigned integer
		const ExpressionType& resolvedFromExprType = ResolveAlias(*fromExprType);
		if (!IsPrimitiveType(resolvedFromExprType))
			throw ShaderLang::CompilerForFromTypeExpectIntegerTypeError{ fromExpr->sourceLocation, ToString(*fromExprType, fromExpr->sourceLocation) };

		PrimitiveType counterType = std::get(resolvedFromExprType);
		if (counterType != PrimitiveType::Int32 && counterType != PrimitiveType::UInt32)
			throw ShaderLang::CompilerForFromTypeExpectIntegerTypeError{ fromExpr->sourceLocation, ToString(*fromExprType, fromExpr->sourceLocation) };

		// Target and step have to be of the same type as the counter
		const ExpressionType& resolvedToExprType = ResolveAlias(*toExprType);
		if (resolvedToExprType != resolvedFromExprType)
			throw ShaderLang::CompilerForToUnmatchingTypeError{ toExpr->sourceLocation, ToString(*toExprType, toExpr->sourceLocation), ToString(*fromExprType, fromExpr->sourceLocation) };

		if (stepExpr)
		{
			const ExpressionType* stepExprType = GetExpressionType(*stepExpr);
			if (!stepExprType)
				return CloneFor(); //< unresolved step type

			const ExpressionType& resolvedStepExprType = ResolveAlias(*stepExprType);
			if (resolvedStepExprType != resolvedFromExprType)
				throw ShaderLang::CompilerForStepUnmatchingTypeError{ stepExpr->sourceLocation, ToString(*stepExprType, stepExpr->sourceLocation), ToString(*fromExprType, fromExpr->sourceLocation) };
		}

		if (unrollValue.HasValue())
		{
			assert(unrollValue.IsResultingValue());
			if (unrollValue.GetResultingValue() == LoopUnroll::Always)
			{
				// Unrolling requires from/to/step to be constant values
				std::optional fromValue = ComputeConstantValue(*fromExpr);
				std::optional toValue = ComputeConstantValue(*toExpr);
				if (!fromValue.has_value() || !toValue.has_value())
					return CloneFor(); //< can't resolve from/to value

				std::optional stepValue;
				if (stepExpr)
				{
					stepValue = ComputeConstantValue(*stepExpr);
					if (!stepValue.has_value())
						return CloneFor(); //< can't resolve step value
				}

				auto multi = std::make_unique();
				multi->sourceLocation = node.sourceLocation;

				// Emits one scoped copy of the loop body per iteration; the parameter only carries the counter type
				auto Unroll = [&](auto dummy)
				{
					using T = std::decay_t;

					T counter = std::get(*fromValue);
					T to = std::get(*toValue);
					T step = (stepExpr) ? std::get(*stepValue) : T(1);

					// NOTE(review): no guard is visible here against a step which never reaches `to` (e.g. a step of 0)
					for (; counter < to; counter += step)
					{
						PushScope();

						auto innerMulti = std::make_unique();
						innerMulti->sourceLocation = node.sourceLocation;

						// The counter becomes a variable initialized with the constant value of this iteration
						auto constant = ShaderBuilder::Constant(counter);
						constant->sourceLocation = node.sourceLocation;

						auto var = ShaderBuilder::DeclareVariable(node.varName, std::move(constant));
						var->sourceLocation = node.sourceLocation;

						Validate(*var);
						innerMulti->statements.emplace_back(std::move(var));

						// Remove any sanitized scope to prevent double-scoping
						innerMulti->statements.emplace_back(Unscope(CloneStatement(node.statement)));

						multi->statements.emplace_back(ShaderBuilder::Scoped(std::move(innerMulti)));

						PopScope();
					}
				};

				switch (counterType)
				{
					case PrimitiveType::Int32:
						Unroll(Int32{});
						break;

					case PrimitiveType::UInt32:
						Unroll(UInt32{});
						break;

					default:
						throw ShaderLang::AstInternalError{ node.sourceLocation, "unexpected counter type " };
				}

				return multi;
			}
		}

		if (m_context->options.reduceLoopsToWhile)
		{
			// Rewrite the loop as: counter/target/step variable declarations followed by a while loop
			PushScope();

			auto multi = std::make_unique();

			// Counter variable
			auto counterVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(fromExpr));
			counterVariable->sourceLocation = node.sourceLocation;
			counterVariable->varIndex = node.varIndex;
			Validate(*counterVariable);

			std::size_t counterVarIndex = counterVariable->varIndex.value();
			multi->statements.emplace_back(std::move(counterVariable));

			// Target variable
			auto targetVariable = ShaderBuilder::DeclareVariable("to", std::move(toExpr));
			targetVariable->sourceLocation = node.sourceLocation;
			Validate(*targetVariable);

			std::size_t targetVarIndex = targetVariable->varIndex.value();
			multi->statements.emplace_back(std::move(targetVariable));

			// Step variable
			std::optional stepVarIndex;

			if (stepExpr)
			{
				auto stepVariable = ShaderBuilder::DeclareVariable("step", std::move(stepExpr));
				stepVariable->sourceLocation = node.sourceLocation;
				Validate(*stepVariable);

				stepVarIndex = stepVariable->varIndex;
				multi->statements.emplace_back(std::move(stepVariable));
			}

			// While
			auto whileStatement = std::make_unique();
			whileStatement->unroll = std::move(unrollValue);

			// While condition
			auto conditionCounterVariable = ShaderBuilder::Variable(counterVarIndex, counterType);
			conditionCounterVariable->sourceLocation = node.sourceLocation;

			auto conditionTargetVariable = ShaderBuilder::Variable(targetVarIndex, counterType);
			conditionTargetVariable->sourceLocation = node.sourceLocation;

			auto condition = ShaderBuilder::Binary(BinaryType::CompLt, std::move(conditionCounterVariable), std::move(conditionTargetVariable));
			condition->sourceLocation = node.sourceLocation;
			Validate(*condition);

			whileStatement->condition = std::move(condition);

			// While body
			auto body = std::make_unique();
			body->statements.reserve(2);

			body->statements.emplace_back(Unscope(CloneStatement(node.statement)));

			// Counter increment: the step variable if there is one, a constant 1 of the counter type otherwise
			ExpressionPtr incrExpr;
			if (stepVarIndex)
				incrExpr = ShaderBuilder::Variable(*stepVarIndex, counterType);
			else
				incrExpr = (counterType == PrimitiveType::Int32) ? ShaderBuilder::Constant(1) : ShaderBuilder::Constant(1u);

			auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, counterType), std::move(incrExpr));
			incrCounter->sourceLocation = node.sourceLocation;
			Validate(*incrCounter);

			body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter)));

			whileStatement->body = std::move(body);

			multi->statements.emplace_back(std::move(whileStatement));

			PopScope();

			return multi;
		}
		else
			return CloneFor();
	}

	// Sanitizes a for-each loop, which only supports arrays.
	// Depending on the unroll attribute and on the reduceLoopsToWhile option, the loop is either unrolled,
	// rewritten as a while loop, or cloned as a for-each loop.
	// Throws AstEmptyIdentifierError or CompilerForEachUnsupportedTypeError.
	StatementPtr SanitizeVisitor::Clone(ForEachStatement& node)
	{
		auto expr = CloneExpression(MandatoryExpr(node.expression, node.sourceLocation));

		if (node.varName.empty())
			throw ShaderLang::AstEmptyIdentifierError{ node.sourceLocation };

		const ExpressionType* exprType = GetExpressionType(*expr);
		if (!exprType)
			return AstCloner::Clone(node); //< unresolved expression type

		const ExpressionType& resolvedExprType = ResolveAlias(*exprType);

		// Type of the iteration variable
		ExpressionType innerType;
		if (IsArrayType(resolvedExprType))
		{
			const ArrayType& arrayType =
std::get(resolvedExprType);
			innerType = arrayType.containedType->type;
		}
		else
			throw ShaderLang::CompilerForEachUnsupportedTypeError{ node.sourceLocation, ToString(*exprType, node.sourceLocation) };

		ExpressionValue unrollValue;
		if (node.unroll.HasValue())
		{
			if (ComputeExprValue(node.unroll, unrollValue, node.sourceLocation) == ValidationResult::Unresolved)
				return AstCloner::Clone(node); //< unresolved unroll type

			if (unrollValue.GetResultingValue() == LoopUnroll::Always)
			{
				PushScope();

				// Repeat code
				auto multi = std::make_unique();
				multi->sourceLocation = node.sourceLocation;

				if (IsArrayType(resolvedExprType))
				{
					const ArrayType& arrayType = std::get(resolvedExprType);

					// One scoped copy of the body per array element
					for (UInt32 i = 0; i < arrayType.length; ++i)
					{
						PushScope();

						auto innerMulti = std::make_unique();
						innerMulti->sourceLocation = node.sourceLocation;

						// The element variable is initialized with expr[i] (the expression is cloned for each iteration)
						auto accessIndex = ShaderBuilder::AccessIndex(CloneExpression(expr), ShaderBuilder::Constant(i));
						Validate(*accessIndex);

						auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
						Validate(*elementVariable);

						innerMulti->statements.emplace_back(std::move(elementVariable));
						innerMulti->statements.emplace_back(Unscope(CloneStatement(node.statement)));

						multi->statements.emplace_back(ShaderBuilder::Scoped(std::move(innerMulti)));

						PopScope();
					}
				}

				PopScope();

				return multi;
			}
		}

		if (m_context->options.reduceLoopsToWhile)
		{
			// Rewrite the loop as a counter variable declaration followed by a while loop indexing the array
			PushScope();

			auto multi = std::make_unique();

			if (IsArrayType(resolvedExprType))
			{
				const ArrayType& arrayType = std::get(resolvedExprType);

				multi->statements.reserve(2);

				// Counter variable
				auto counterVariable = ShaderBuilder::DeclareVariable("i", ShaderBuilder::Constant(0u));
				Validate(*counterVariable);

				std::size_t counterVarIndex = counterVariable->varIndex.value();

				multi->statements.emplace_back(std::move(counterVariable));

				auto whileStatement = std::make_unique();
				whileStatement->unroll = std::move(unrollValue);

				// While condition
				auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(arrayType.length));
				Validate(*condition);
				whileStatement->condition = std::move(condition);

				// While body
				auto body = std::make_unique();
				body->statements.reserve(3);

				// Element variable, initialized with expr[counter] (expr is moved from here)
				auto accessIndex = ShaderBuilder::AccessIndex(std::move(expr), ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32));
				Validate(*accessIndex);

				auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex));
				elementVariable->varIndex = node.varIndex; //< Preserve var index
				Validate(*elementVariable);
				body->statements.emplace_back(std::move(elementVariable));

				body->statements.emplace_back(Unscope(CloneStatement(node.statement)));

				// Counter increment
				auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32), ShaderBuilder::Constant(1u));
				Validate(*incrCounter);

				body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter)));

				whileStatement->body = std::move(body);

				multi->statements.emplace_back(std::move(whileStatement));
			}

			PopScope();

			return multi;
		}
		else
		{
			// Keep a for-each statement
			auto clone = std::make_unique();
			clone->expression = std::move(expr);
			clone->varName = node.varName;
			clone->unroll = std::move(unrollValue);
			clone->sourceLocation = node.sourceLocation;

			// The element variable only lives in the scope of the loop
			PushScope();
			{
				clone->varIndex = RegisterVariable(node.varName, innerType, node.varIndex, node.sourceLocation);
				clone->statement = CloneStatement(node.statement);
			}
			PopScope();

			SanitizeIdentifier(clone->varName);

			return clone;
		}
	}

	// Sanitizes a module import: resolves the module, sanitizes it the first time it is imported,
	// then declares aliases to its exported functions and structs.
	// Throws CompilerNoModuleResolverError, CompilerModuleNotFoundError,
	// CompilerModuleCompilationFailedError or CompilerCircularImportError.
	StatementPtr SanitizeVisitor::Clone(ImportStatement& node)
	{
		if (!m_context->options.moduleResolver)
		{
			if (!m_context->options.allowPartialSanitization)
				throw ShaderLang::CompilerNoModuleResolverError{ node.sourceLocation };

			// when partially sanitizing, importing a whole module could register any identifier, so at this point we can't see unknown identifiers as errors
			m_context->allowUnknownIdentifiers = true;

			return
StaticUniquePointerCast(AstCloner::Clone(node));
		}

		ModulePtr targetModule = m_context->options.moduleResolver->Resolve(node.moduleName);
		if (!targetModule)
			throw ShaderLang::CompilerModuleNotFoundError{ node.sourceLocation, node.moduleName };

		std::size_t moduleIndex;

		const Uuid& moduleUuid = targetModule->metadata->moduleId;
		auto it = m_context->moduleByUuid.find(moduleUuid);
		if (it == m_context->moduleByUuid.end())
		{
			// First import of this module: mark it as being imported, to detect circular imports
			m_context->moduleByUuid[moduleUuid] = Context::ModuleIdSentinel;

			// Generate module identifier (based on UUID)
			const auto& moduleUuidBytes = moduleUuid.ToArray();

			// NOTE(review): End() is called twice on the hasher (once here, once when building the identifier) -
			// confirm the second call still yields the intended digest.
			SHA256Hash hasher;
			hasher.Begin();
			hasher.Append(moduleUuidBytes.data(), moduleUuidBytes.size());
			hasher.End();

			// Identifier is an underscore followed by the first 8 characters of the hexadecimal hash
			std::string identifier = "_" + hasher.End().ToHex().substr(0, 8);

			// Load new module
			auto moduleEnvironment = std::make_shared();
			moduleEnvironment->parentEnv = m_context->globalEnv;

			// The module is sanitized in its own environment, the current one is restored afterwards
			auto previousEnv = m_context->currentEnv;
			m_context->currentEnv = moduleEnvironment;

			ModulePtr sanitizedModule = std::make_shared(targetModule->metadata);

			// Remap already used indices
			IndexRemapperVisitor::Options indexCallbacks;
			indexCallbacks.aliasIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->aliases.RegisterNewIndex(true); };
			indexCallbacks.constIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->constantValues.RegisterNewIndex(true); };
			indexCallbacks.funcIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->functions.RegisterNewIndex(true); };
			indexCallbacks.structIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->structs.RegisterNewIndex(true); };
			indexCallbacks.varIndexGenerator = [this](std::size_t /*previousIndex*/) { return m_context->variableTypes.RegisterNewIndex(true); };
			indexCallbacks.forceIndexGeneration = true;

			sanitizedModule->rootNode = StaticUniquePointerCast(RemapIndices(*targetModule->rootNode, indexCallbacks));

			std::string error;
			sanitizedModule->rootNode = SanitizeInternal(*sanitizedModule->rootNode, &error);
			if (!sanitizedModule->rootNode)
				throw ShaderLang::CompilerModuleCompilationFailedError{ node.sourceLocation, node.moduleName, error };

			moduleIndex = m_context->modules.size();

			assert(m_context->modules.size() == moduleIndex);
			auto& moduleData = m_context->modules.emplace_back();

			// Don't run dependency checker when partially sanitizing
			if (!m_context->options.allowPartialSanitization)
			{
				moduleData.dependenciesVisitor = std::make_unique();
				moduleData.dependenciesVisitor->Register(*sanitizedModule->rootNode);
			}

			moduleData.environment = std::move(moduleEnvironment);

			// The same index is expected to be valid in both the module data list and the imported module list
			assert(m_context->currentModule->importedModules.size() == moduleIndex);
			auto& importedModule = m_context->currentModule->importedModules.emplace_back();
			importedModule.identifier = identifier;
			importedModule.module = std::move(sanitizedModule);

			m_context->currentEnv = std::move(previousEnv);

			RegisterModule(identifier, moduleIndex);

			// Replace the sentinel with the actual module index
			m_context->moduleByUuid[moduleUuid] = moduleIndex;
		}
		else
		{
			// Module has already been imported
			moduleIndex = it->second;
			// The sentinel means this module is still being imported
			if (moduleIndex == Context::ModuleIdSentinel)
				throw ShaderLang::CompilerCircularImportError{ node.sourceLocation, node.moduleName };
		}

		auto& moduleData = m_context->modules[moduleIndex];

		// Exported symbols already aliased, tracked per id of the importing environment's module
		auto& exportedSet = moduleData.exportedSetByModule[m_context->currentEnv->moduleId];

		// Extract exported nodes and their dependencies
		std::vector aliasStatements;

		AstExportVisitor::Callbacks callbacks;
		callbacks.onExportedFunc = [&](DeclareFunctionStatement& node)
		{
			assert(node.funcIndex);

			if (moduleData.dependenciesVisitor)
				moduleData.dependenciesVisitor->MarkFunctionAsUsed(*node.funcIndex);

			// Declare an alias only the first time this function is exported to this environment
			if (!exportedSet.usedFunctions.UnboundedTest(*node.funcIndex))
			{
				exportedSet.usedFunctions.UnboundedSet(*node.funcIndex);
				aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.name, ShaderBuilder::Function(*node.funcIndex)));
			}
		};

		callbacks.onExportedStruct = [&](DeclareStructStatement& node)
		{
			assert(node.structIndex);

			if (moduleData.dependenciesVisitor)
				moduleData.dependenciesVisitor->MarkStructAsUsed(*node.structIndex);

			// Declare an alias only the first time this struct is exported to this environment
			if (!exportedSet.usedStructs.UnboundedTest(*node.structIndex))
			{
				exportedSet.usedStructs.UnboundedSet(*node.structIndex);
				aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.description.name, ShaderBuilder::StructType(*node.structIndex)));
			}
		};

		AstExportVisitor exportVisitor;
		exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);

		if (aliasStatements.empty())
			return ShaderBuilder::NoOp();

		// Register aliases
		for (auto& aliasPtr : aliasStatements)
			Validate(*aliasPtr);

		// Aliases stay registered even when their declarations are dropped from the output
		if (m_context->options.removeAliases)
			return ShaderBuilder::NoOp();

		// Generate alias statements
		MultiStatementPtr aliasBlock = std::make_unique();
		for (auto& aliasPtr : aliasStatements)
			aliasBlock->statements.push_back(std::move(aliasPtr));

		m_context->allowUnknownIdentifiers = true; //< if module uses an unresolved and non-exported symbol, we need to allow unknown identifiers

		return aliasBlock;
	}

	// Sanitizes a statement list.
	// While its statements are cloned, the clone's list is exposed as the current statement list
	// (CacheResult appends declarations to it); the previous list is restored afterwards.
	StatementPtr SanitizeVisitor::Clone(MultiStatement& node)
	{
		auto clone = std::make_unique();
		clone->statements.reserve(node.statements.size());

		std::vector* previousList = m_context->currentStatementList;
		m_context->currentStatementList = &clone->statements;

		for (auto& statement : node.statements)
			clone->statements.push_back(AstCloner::Clone(MandatoryStatement(statement, node.sourceLocation)));

		m_context->currentStatementList = previousList;

		return clone;
	}

	// Sanitizes a scoped statement: its content is cloned inside a new identifier scope.
	StatementPtr SanitizeVisitor::Clone(ScopedStatement& node)
	{
		MandatoryStatement(node.statement, node.sourceLocation);

		PushScope();

		auto scopedClone = AstCloner::Clone(node);

		PopScope();

		return scopedClone;
	}

	// Sanitizes a while loop (condition and body are mandatory).
	// Throws CompilerWhileUnrollNotSupportedError if the loop is asked to always unroll.
	StatementPtr SanitizeVisitor::Clone(WhileStatement& node)
	{
		MandatoryExpr(node.condition, node.sourceLocation);
		MandatoryStatement(node.body, node.sourceLocation);

		auto clone = StaticUniquePointerCast(AstCloner::Clone(node));
		if (Validate(*clone) == ValidationResult::Unresolved)
			return clone;
if (clone->unroll.HasValue()) { if (ComputeExprValue(clone->unroll, node.sourceLocation) == ValidationResult::Validated && clone->unroll.GetResultingValue() == LoopUnroll::Always) throw ShaderLang::CompilerWhileUnrollNotSupportedError{ node.sourceLocation }; } return clone; } auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const IdentifierData* { return FindIdentifier(*m_context->currentEnv, identifierName); } template auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName, F&& functor) const -> const IdentifierData* { return FindIdentifier(*m_context->currentEnv, identifierName, std::forward(functor)); } auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName) const -> const IdentifierData* { auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); if (it == environment.identifiersInScope.rend()) { if (environment.parentEnv) return FindIdentifier(*environment.parentEnv, identifierName); else return nullptr; } return &it->target; } template auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const -> const IdentifierData* { auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier) { if (identifier.name == identifierName) { if (functor(identifier.target)) return true; } return false; }); if (it == environment.identifiersInScope.rend()) { if (environment.parentEnv) return FindIdentifier(*environment.parentEnv, identifierName, std::forward(functor)); else return nullptr; } return &it->target; } const ExpressionType* SanitizeVisitor::GetExpressionType(Expression& expr) const { const ExpressionType* expressionType = ShaderAst::GetExpressionType(expr); if (!expressionType) { if 
(!m_context->options.allowPartialSanitization)
				throw ShaderLang::AstInternalError{ expr.sourceLocation, "unexpected missing expression type" };
		}

		return expressionType;
	}

	// Returns the type of an expression, which has to be available.
	// Throws AstInternalError if the expression has no type (even when partially sanitizing).
	const ExpressionType& SanitizeVisitor::GetExpressionTypeSecure(Expression& expr) const
	{
		const ExpressionType* expressionType = GetExpressionType(expr);
		if (!expressionType)
			throw ShaderLang::AstInternalError{ expr.sourceLocation, "unexpected missing expression type" };

		return *expressionType;
	}

	// Builds the expression referring to what an identifier designates, based on its category.
	// identifierData is dereferenced without any check and so must not be null.
	// Throws AstUnexpectedIdentifierError for module and unresolved identifiers, and AstInternalError
	// for a category the switch doesn't handle.
	ExpressionPtr SanitizeVisitor::HandleIdentifier(const IdentifierData* identifierData, const ShaderLang::SourceLocation& sourceLocation)
	{
		switch (identifierData->category)
		{
			case IdentifierCategory::Alias:
			{
				// Build a temporary alias expression and let the alias Clone overload handle it
				AliasValueExpression aliasValue;
				aliasValue.aliasId = identifierData->index;
				aliasValue.sourceLocation = sourceLocation;

				return Clone(aliasValue);
			}

			case IdentifierCategory::Constant:
			{
				// Replace IdentifierExpression by Constant(Value)Expression
				ConstantExpression constantExpr;
				constantExpr.constantId = identifierData->index;
				constantExpr.sourceLocation = sourceLocation;

				return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
			}

			case IdentifierCategory::Function:
			{
				// Replace IdentifierExpression by FunctionExpression
				auto funcExpr = std::make_unique();
				funcExpr->cachedExpressionType = FunctionType{ identifierData->index }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type
				funcExpr->funcId = identifierData->index;
				funcExpr->sourceLocation = sourceLocation;

				return funcExpr;
			}

			case IdentifierCategory::Intrinsic:
			{
				IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index, sourceLocation);

				// Replace IdentifierExpression by IntrinsicFunctionExpression
				auto intrinsicExpr = std::make_unique();
				intrinsicExpr->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type
				intrinsicExpr->intrinsicId = identifierData->index;
				intrinsicExpr->sourceLocation = sourceLocation;

				return intrinsicExpr;
			}

			case IdentifierCategory::Module:
				throw ShaderLang::AstUnexpectedIdentifierError{ sourceLocation, "module" };

			case IdentifierCategory::Struct:
			{
				// Replace IdentifierExpression by StructTypeExpression
				auto structExpr = std::make_unique();
				structExpr->cachedExpressionType = StructType{ identifierData->index };
				structExpr->sourceLocation = sourceLocation;
				structExpr->structTypeId = identifierData->index;

				return structExpr;
			}

			case IdentifierCategory::Type:
			{
				// Build a type expression carrying the type index
				auto typeExpr = std::make_unique();
				typeExpr->cachedExpressionType = Type{ identifierData->index };
				typeExpr->sourceLocation = sourceLocation;
				typeExpr->typeId = identifierData->index;

				return typeExpr;
			}

			case IdentifierCategory::Unresolved:
				throw ShaderLang::AstUnexpectedIdentifierError{ sourceLocation, "unresolved" };

			case IdentifierCategory::Variable:
			{
				// Replace IdentifierExpression by VariableExpression
				auto varExpr = std::make_unique();
				varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index, sourceLocation);
				varExpr->sourceLocation = sourceLocation;
				varExpr->variableId = identifierData->index;

				return varExpr;
			}
		}

		throw ShaderLang::AstInternalError{ sourceLocation, "unhandled identifier category" };
	}

	// Opens a new scope in the current environment, remembering how many identifiers were visible.
	void SanitizeVisitor::PushScope()
	{
		auto& scope = m_context->currentEnv->scopes.emplace_back();
		scope.previousSize = m_context->currentEnv->identifiersInScope.size();
	}

	// Closes the innermost scope of the current environment: the identifier list is resized back
	// to the size recorded by the matching PushScope.
	void SanitizeVisitor::PopScope()
	{
		assert(!m_context->currentEnv->scopes.empty());
		auto& scope = m_context->currentEnv->scopes.back();
		m_context->currentEnv->identifiersInScope.resize(scope.previousSize);
		m_context->currentEnv->scopes.pop_back();
	}

	// Stores the result of an expression in a new variable, declared in the current statement list,
	// and returns an expression referring to that variable. LValues are returned as is.
	ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
	{
		// No need to cache LValues (variables/constants) (TODO: Improve this, as constants don't need to be cached as well)
		if (GetExpressionCategory(*expression) == ExpressionCategory::LValue)
			return expression;
assert(m_context->currentStatementList);

		auto variableDeclaration = ShaderBuilder::DeclareVariable("cachedResult", std::move(expression)); //< Validation will prevent name-clash if required
		Validate(*variableDeclaration);

		auto varExpr = std::make_unique();
		varExpr->sourceLocation = variableDeclaration->initialExpression->sourceLocation;
		varExpr->variableId = *variableDeclaration->varIndex;

		// The declaration is appended to the statement list currently being built
		m_context->currentStatementList->push_back(std::move(variableDeclaration));

		return varExpr;
	}

	// Tries to reduce an expression to a single constant value.
	// Returns std::nullopt if it can't be reduced while partially sanitizing, and throws
	// CompilerConstantExpressionRequiredError otherwise.
	std::optional SanitizeVisitor::ComputeConstantValue(Expression& expr) const
	{
		// Run optimizer on constant value to hopefully retrieve a single constant value
		ExpressionPtr optimizedExpr = PropagateConstants(expr);
		if (optimizedExpr->GetType() != NodeType::ConstantValueExpression)
		{
			if (!m_context->options.allowPartialSanitization)
				throw ShaderLang::CompilerConstantExpressionRequiredError{ expr.sourceLocation };

			return std::nullopt;
		}

		return static_cast(*optimizedExpr).value;
	}

	// Evaluates an attribute in place: if it holds an expression, the expression is replaced by its constant value.
	// Returns Unresolved if the expression could not be reduced to a constant, Validated otherwise
	// (including when the attribute doesn't hold an expression).
	// Throws AstAttributeRequiresValueError, CompilerAttributeUnexpectedTypeError or
	// CompilerAttributeUnexpectedExpressionError.
	template auto SanitizeVisitor::ComputeExprValue(ExpressionValue& attribute, const ShaderLang::SourceLocation& sourceLocation) const -> ValidationResult
	{
		if (!attribute.HasValue())
			throw ShaderLang::AstAttributeRequiresValueError{ sourceLocation };

		if (attribute.IsExpression())
		{
			auto& expr = *attribute.GetExpression();

			std::optional value = ComputeConstantValue(expr);
			if (!value)
				return ValidationResult::Unresolved;

			if constexpr (TypeListFind)
			{
				if (!std::holds_alternative(*value))
				{
					// HAAAAAX
					// NOTE(review): special case converting a constant of another type to the attribute type
					// (the types involved are not visible in this view)
					if (std::holds_alternative(*value) && std::is_same_v)
						attribute = static_cast(std::get(*value));
					else
						throw ShaderLang::CompilerAttributeUnexpectedTypeError{ expr.sourceLocation };
				}
				else
					attribute = std::get(*value);
			}
			else
				throw ShaderLang::CompilerAttributeUnexpectedExpressionError{ expr.sourceLocation };
		}

		return ValidationResult::Validated;
	}

	// Evaluates an attribute into another one, leaving the source attribute untouched.
	// If the expression could not be reduced to a constant, the target receives a clone of it and
	// Unresolved is returned; otherwise the target receives the value and Validated is returned.
	// Throws AstAttributeRequiresValueError, CompilerAttributeUnexpectedTypeError or
	// CompilerAttributeUnexpectedExpressionError.
	template auto SanitizeVisitor::ComputeExprValue(const ExpressionValue& attribute, ExpressionValue& targetAttribute, const ShaderLang::SourceLocation& sourceLocation) -> ValidationResult
	{
		if (!attribute.HasValue())
			throw ShaderLang::AstAttributeRequiresValueError{ sourceLocation };

		if (attribute.IsExpression())
		{
			auto& expr = *attribute.GetExpression();

			std::optional value = ComputeConstantValue(*attribute.GetExpression());
			if (!value)
			{
				targetAttribute = AstCloner::Clone(*attribute.GetExpression());
				return ValidationResult::Unresolved;
			}

			if constexpr (TypeListFind)
			{
				if (!std::holds_alternative(*value))
				{
					// HAAAAAX
					// NOTE(review): same special-case conversion as in the in-place overload
					if (std::holds_alternative(*value) && std::is_same_v)
						targetAttribute = static_cast(std::get(*value));
					else
						throw ShaderLang::CompilerAttributeUnexpectedTypeError{ expr.sourceLocation };
				}
				else
					targetAttribute = std::get(*value);
			}
			else
				throw ShaderLang::CompilerAttributeUnexpectedExpressionError{ expr.sourceLocation };
		}
		else
		{
			// Already a value, copy it
			assert(attribute.IsResultingValue());
			targetAttribute = attribute.GetResultingValue();
		}

		return ValidationResult::Validated;
	}

	// Runs constant propagation on a node, resolving constant indices using the registered constant values.
	template std::unique_ptr SanitizeVisitor::PropagateConstants(T& node) const
	{
		AstConstantPropagationVisitor::Options optimizerOptions;
		optimizerOptions.constantQueryCallback = [&](std::size_t constantId) -> const ConstantValue*
		{
			// A constant without a value is only tolerated when partially sanitizing
			const ConstantValue* value = m_context->constantValues.TryRetrieve(constantId, node.sourceLocation);
			if (!value && !m_context->options.allowPartialSanitization)
				throw ShaderLang::AstInvalidConstantIndexError{ node.sourceLocation, constantId };

			return value;
		};

		// Run optimizer on constant value to hopefully retrieve a single constant value
		return StaticUniquePointerCast(ShaderAst::PropagateConstants(node, optimizerOptions));
	}

	// Preregisters every index already present in a module and in its imported modules.
	void SanitizeVisitor::PreregisterIndices(const Module& module)
	{
		// If AST has been sanitized before and is sanitized again but with different options that may introduce new variables (for example reduceLoopsToWhile)
		// we have to make sure we won't override variable indices. This is done by visiting the AST a first time and preregistering all indices.
// TODO: Only do this is the AST has been already sanitized, maybe using a flag stored in the module? AstReflect::Callbacks registerCallbacks; registerCallbacks.onAliasIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->aliases.PreregisterIndex(index, sourceLocation); }; registerCallbacks.onConstIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->constantValues.PreregisterIndex(index, sourceLocation); }; registerCallbacks.onFunctionIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->functions.PreregisterIndex(index, sourceLocation); }; registerCallbacks.onOptionIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->constantValues.PreregisterIndex(index, sourceLocation); }; registerCallbacks.onStructIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->structs.PreregisterIndex(index, sourceLocation); }; registerCallbacks.onVariableIndex = [this](const std::string& /*name*/, std::size_t index, const ShaderLang::SourceLocation& sourceLocation) { m_context->variableTypes.PreregisterIndex(index, sourceLocation); }; AstReflect reflectVisitor; for (const auto& importedModule : module.importedModules) reflectVisitor.Reflect(*importedModule.module->rootNode, registerCallbacks); reflectVisitor.Reflect(*module.rootNode, registerCallbacks); } void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen) { auto& funcData = m_context->functions.Retrieve(funcIndex, {}); funcData.flags |= flags; for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i)) PropagateFunctionFlags(i, funcData.flags, 
seen); } void SanitizeVisitor::RegisterBuiltin() { // Primitive types RegisterType("bool", PrimitiveType::Boolean, std::nullopt, {}); RegisterType("f32", PrimitiveType::Float32, std::nullopt, {}); RegisterType("i32", PrimitiveType::Int32, std::nullopt, {}); RegisterType("u32", PrimitiveType::UInt32, std::nullopt, {}); // Partial types // Array RegisterType("array", PartialType { { TypeParameterCategory::FullType, TypeParameterCategory::ConstantValue }, [=](const TypeParameter* parameters, std::size_t parameterCount, const ShaderLang::SourceLocation& sourceLocation) -> ExpressionType { assert(parameterCount == 2); assert(std::holds_alternative(parameters[0])); assert(std::holds_alternative(parameters[1])); const ExpressionType& exprType = std::get(parameters[0]); const ConstantValue& length = std::get(parameters[1]); UInt32 lengthValue; if (std::holds_alternative(length)) { Int32 value = std::get(length); if (value <= 0) throw ShaderLang::CompilerArrayLengthError{ sourceLocation, std::to_string(value) }; lengthValue = SafeCast(value); } else if (std::holds_alternative(length)) { lengthValue = std::get(length); if (lengthValue == 0) throw ShaderLang::CompilerArrayLengthError{ sourceLocation, std::to_string(lengthValue) }; } else throw ShaderLang::CompilerArrayLengthError{ sourceLocation, ToString(GetConstantType(length), sourceLocation) }; ArrayType arrayType; arrayType.containedType = std::make_unique(); arrayType.containedType->type = exprType; arrayType.length = lengthValue; return arrayType; } }, std::nullopt, {}); // matX for (std::size_t componentCount = 2; componentCount <= 4; ++componentCount) { RegisterType("mat" + std::to_string(componentCount), PartialType { { TypeParameterCategory::PrimitiveType }, [=](const TypeParameter* parameters, std::size_t parameterCount, const ShaderLang::SourceLocation& /*sourceLocation*/) -> ExpressionType { assert(parameterCount == 1); assert(std::holds_alternative(*parameters)); const ExpressionType& exprType = 
std::get(*parameters);
				assert(IsPrimitiveType(exprType));

				// Square matrix: same count for both dimensions
				return MatrixType {
					componentCount, componentCount, std::get(exprType)
				};
			}
		}, std::nullopt, {});
	}

	// vecX
	for (std::size_t componentCount = 2; componentCount <= 4; ++componentCount)
	{
		RegisterType("vec" + std::to_string(componentCount), PartialType {
			{ TypeParameterCategory::PrimitiveType },
			[=](const TypeParameter* parameters, std::size_t parameterCount, const ShaderLang::SourceLocation& /*sourceLocation*/) -> ExpressionType
			{
				assert(parameterCount == 1);
				assert(std::holds_alternative(*parameters));

				const ExpressionType& exprType = std::get(*parameters);
				assert(IsPrimitiveType(exprType));

				return VectorType {
					componentCount, std::get(exprType)
				};
			}
		}, std::nullopt, {});
	}

	// samplers
	struct SamplerInfo
	{
		std::string typeName;
		ImageType imageType;
	};

	std::array samplerInfos =
	{
		{
			{
				"sampler2D",
				ImageType::E2D
			},
			{
				"samplerCube",
				ImageType::Cubemap
			}
		}
	};

	for (SamplerInfo& sampler : samplerInfos)
	{
		// The lambda captures sampler by copy ([=]) and only reads its imageType
		RegisterType(std::move(sampler.typeName), PartialType {
			{ TypeParameterCategory::PrimitiveType },
			[=](const TypeParameter* parameters, std::size_t parameterCount, const ShaderLang::SourceLocation& sourceLocation) -> ExpressionType
			{
				assert(parameterCount == 1);
				assert(std::holds_alternative(*parameters));

				const ExpressionType& exprType = std::get(*parameters);
				assert(IsPrimitiveType(exprType));

				PrimitiveType primitiveType = std::get(exprType);

				// TODO: Add support for integer samplers
				if (primitiveType != PrimitiveType::Float32)
					throw ShaderLang::CompilerSamplerUnexpectedTypeError{ sourceLocation, ToString(exprType, sourceLocation) };

				return SamplerType {
					sampler.imageType, primitiveType
				};
			}
		}, std::nullopt, {});
	}

	// uniform
	RegisterType("uniform", PartialType {
		{ TypeParameterCategory::StructType },
		[=](const TypeParameter* parameters, std::size_t parameterCount, const ShaderLang::SourceLocation& /*sourceLocation*/) -> ExpressionType
		{
			assert(parameterCount == 1);
			assert(std::holds_alternative(*parameters));

			const ExpressionType& exprType = std::get(*parameters);
			assert(IsStructType(exprType));

			StructType structType = std::get(exprType);
			return UniformType {
				structType
			};
		}
	}, std::nullopt, {});

	// Intrinsics
	RegisterIntrinsic("cross", IntrinsicType::CrossProduct);
	RegisterIntrinsic("dot", IntrinsicType::DotProduct);
	RegisterIntrinsic("exp", IntrinsicType::Exp);
	RegisterIntrinsic("length", IntrinsicType::Length);
	RegisterIntrinsic("max", IntrinsicType::Max);
	RegisterIntrinsic("min", IntrinsicType::Min);
	RegisterIntrinsic("normalize", IntrinsicType::Normalize);
	RegisterIntrinsic("pow", IntrinsicType::Pow);
	RegisterIntrinsic("reflect", IntrinsicType::Reflect);
}

// Declares an alias named name in the current environment and returns its index.
// Throws CompilerIdentifierAlreadyUsedError if FindIdentifier already knows the name.
// With aliasData the alias is fully registered (at index when given); without it, only the index is
// reserved: the given one is preregistered, or a fresh preregistered index is allocated.
std::size_t SanitizeVisitor::RegisterAlias(std::string name, std::optional aliasData, std::optional index, const ShaderLang::SourceLocation& sourceLocation)
{
	if (FindIdentifier(name))
		throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name };

	std::size_t aliasIndex;
	if (aliasData)
		aliasIndex = m_context->aliases.Register(std::move(*aliasData), index, sourceLocation);
	else if (index)
	{
		m_context->aliases.PreregisterIndex(*index, sourceLocation);
		aliasIndex = *index;
	}
	else
		aliasIndex = m_context->aliases.RegisterNewIndex(true);

	m_context->currentEnv->identifiersInScope.push_back({
		std::move(name),
		aliasIndex,
		IdentifierCategory::Alias,
		m_context->inConditionalStatement
	});

	return aliasIndex;
}

// Declares a constant named name in the current environment and returns its index.
// Follows the same registration rules as RegisterAlias, using the constantValues list.
std::size_t SanitizeVisitor::RegisterConstant(std::string name, std::optional value, std::optional index, const ShaderLang::SourceLocation& sourceLocation)
{
	if (FindIdentifier(name))
		throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name };

	std::size_t constantIndex;
	if (value)
		constantIndex = m_context->constantValues.Register(std::move(*value), index, sourceLocation);
	else if (index)
	{
		m_context->constantValues.PreregisterIndex(*index, sourceLocation);
		constantIndex = *index;
	}
	else
		constantIndex = m_context->constantValues.RegisterNewIndex(true);
m_context->currentEnv->identifiersInScope.push_back({
		std::move(name),
		constantIndex,
		IdentifierCategory::Constant,
		m_context->inConditionalStatement
	});

	return constantIndex;
}

// Declares a function named name in the current environment and returns its index.
// An already-known name is rejected with CompilerIdentifierAlreadyUsedError unless one of the
// exceptions handled below applies (partial sanitization, or entry functions of different stages).
// Missing funcData is only tolerated during partial sanitization (AstInternalError otherwise).
std::size_t SanitizeVisitor::RegisterFunction(std::string name, std::optional funcData, std::optional index, const ShaderLang::SourceLocation& sourceLocation)
{
	if (auto* identifier = FindIdentifier(name))
	{
		// Functions can be conditionally defined and condition not resolved yet, allow duplicates when partially sanitizing
		bool duplicate = !m_context->options.allowPartialSanitization;

		// Functions cannot be declared twice, except for entry ones if their stages are different
		if (funcData)
		{
			if (funcData->node->entryStage.HasValue() && identifier->category == IdentifierCategory::Function)
			{
				auto& otherFunction = m_context->functions.Retrieve(identifier->index, sourceLocation);
				if (funcData->node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue())
					duplicate = false;
			}
		}
		else
		{
			if (!m_context->options.allowPartialSanitization)
				throw ShaderLang::AstInternalError{ sourceLocation, "unexpected missing function data" };

			duplicate = false;
		}

		if (duplicate)
			throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name };
	}

	std::size_t functionIndex;
	if (funcData)
		functionIndex = m_context->functions.Register(std::move(*funcData), index, sourceLocation);
	else if (index)
	{
		m_context->functions.PreregisterIndex(*index, sourceLocation);
		functionIndex = *index;
	}
	else
		functionIndex = m_context->functions.RegisterNewIndex(true);

	m_context->currentEnv->identifiersInScope.push_back({
		std::move(name),
		functionIndex,
		IdentifierCategory::Function,
		m_context->inConditionalStatement
	});

	return functionIndex;
}

// Declares an intrinsic named name and returns its newly allocated index.
// Errors are reported with a default-constructed source location.
std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type)
{
	if (FindIdentifier(name))
		throw ShaderLang::CompilerIdentifierAlreadyUsedError{ {}, name };

	std::size_t intrinsicIndex = m_context->intrinsics.Register(std::move(type), std::nullopt, {});

	m_context->currentEnv->identifiersInScope.push_back({
		std::move(name),
		intrinsicIndex,
		IdentifierCategory::Intrinsic,
		m_context->inConditionalStatement
	});

	return intrinsicIndex;
}

// Declares a module identifier and returns the index under which the value `index` was stored
// in m_context->moduleIndices.
std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t index)
{
	if (FindIdentifier(moduleIdentifier))
		throw ShaderLang::CompilerIdentifierAlreadyUsedError{ {}, moduleIdentifier };

	std::size_t moduleIndex = m_context->moduleIndices.Register(index, std::nullopt, {});

	m_context->currentEnv->identifiersInScope.push_back({
		std::move(moduleIdentifier),
		moduleIndex,
		IdentifierCategory::Module,
		m_context->inConditionalStatement
	});

	return moduleIndex;
}

// Declares a struct named name and returns its index.
// A name clash is only accepted when both the existing identifier and the current declaration are
// conditional; the name is then recorded as unresolved instead of as a struct.
std::size_t SanitizeVisitor::RegisterStruct(std::string name, std::optional description, std::optional index, const ShaderLang::SourceLocation& sourceLocation)
{
	bool unresolved = false;
	if (const IdentifierData* identifierData = FindIdentifier(name))
	{
		if (!m_context->inConditionalStatement || !identifierData->isConditional)
			throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name };
		else
			unresolved = true;
	}

	// The struct index is allocated even when the name ends up unresolved
	std::size_t structIndex;
	if (description)
		structIndex = m_context->structs.Register(*description, index, sourceLocation);
	else if (index)
	{
		m_context->structs.PreregisterIndex(*index, sourceLocation);
		structIndex = *index;
	}
	else
		structIndex = m_context->structs.RegisterNewIndex(true);

	if (!unresolved)
	{
		m_context->currentEnv->identifiersInScope.push_back({
			std::move(name),
			{
				structIndex,
				IdentifierCategory::Struct,
				m_context->inConditionalStatement
			}
		});
	}
	else
		RegisterUnresolved(std::move(name));

	return structIndex;
}

// Declares a full type named name in the current environment and returns its index.
// Follows the same registration rules as the other Register* helpers, using the types list.
std::size_t SanitizeVisitor::RegisterType(std::string name, std::optional expressionType, std::optional index, const ShaderLang::SourceLocation& sourceLocation)
{
	if (FindIdentifier(name))
		throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name };

	std::size_t typeIndex;
	if (expressionType)
		typeIndex =
m_context->types.Register(std::move(*expressionType), index, sourceLocation);
	else if (index)
	{
		m_context->types.PreregisterIndex(*index, sourceLocation);
		typeIndex = *index;
	}
	else
		typeIndex = m_context->types.RegisterNewIndex(true);

	m_context->currentEnv->identifiersInScope.push_back({
		std::move(name),
		typeIndex,
		IdentifierCategory::Type,
		m_context->inConditionalStatement
	});

	return typeIndex;
}

// Declares a partial type named name in the current environment and returns its index.
// The partial type is stored together with a copy of its name (NamedPartialType).
std::size_t SanitizeVisitor::RegisterType(std::string name, std::optional partialType, std::optional index, const ShaderLang::SourceLocation& sourceLocation)
{
	if (FindIdentifier(name))
		throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name };

	std::size_t typeIndex;
	if (partialType)
	{
		NamedPartialType namedPartial;
		namedPartial.name = name; //< copied, name is still needed for the identifier below
		namedPartial.type = std::move(*partialType);

		typeIndex = m_context->types.Register(std::move(namedPartial), index, sourceLocation);
	}
	else if (index)
	{
		m_context->types.PreregisterIndex(*index, sourceLocation);
		typeIndex = *index;
	}
	else
		typeIndex = m_context->types.RegisterNewIndex(true);

	m_context->currentEnv->identifiersInScope.push_back({
		std::move(name),
		typeIndex,
		IdentifierCategory::Type,
		m_context->inConditionalStatement
	});

	return typeIndex;
}

// Records name in the current environment as an unresolved identifier, using the maximum index value as placeholder.
void SanitizeVisitor::RegisterUnresolved(std::string name)
{
	m_context->currentEnv->identifiersInScope.push_back({
		std::move(name),
		std::numeric_limits::max(),
		IdentifierCategory::Unresolved,
		m_context->inConditionalStatement
	});
}

// Declares a variable named name and returns its index.
// Shadowing another variable is allowed; clashing with any other identifier category throws
// CompilerIdentifierAlreadyUsedError. When both the shadowed variable and the current declaration
// are conditional, the name is recorded as unresolved.
std::size_t SanitizeVisitor::RegisterVariable(std::string name, std::optional type, std::optional index, const ShaderLang::SourceLocation& sourceLocation)
{
	bool unresolved = false;
	if (auto* identifier = FindIdentifier(name))
	{
		// Allow variable shadowing
		if (identifier->category != IdentifierCategory::Variable)
			throw ShaderLang::CompilerIdentifierAlreadyUsedError{ sourceLocation, name };
		else if (identifier->isConditional && m_context->inConditionalStatement)
			unresolved = true; //< right variable isn't known from this point
	}

	std::size_t varIndex;
	if (type)
		varIndex = m_context->variableTypes.Register(std::move(*type), index, sourceLocation);
	else if (index)
	{
		m_context->variableTypes.PreregisterIndex(*index, sourceLocation);
		varIndex = *index;
	}
	else
		varIndex = m_context->variableTypes.RegisterNewIndex(true);

	if (!unresolved)
	{
		m_context->currentEnv->identifiersInScope.push_back({
			std::move(name),
			{
				varIndex,
				IdentifierCategory::Variable,
				m_context->inConditionalStatement
			}
		});
	}
	else
		RegisterUnresolved(std::move(name));

	return varIndex;
}

// Follows a chain of aliases until reaching an identifier whose target is not itself an alias.
auto SanitizeVisitor::ResolveAliasIdentifier(const Identifier* identifier, const ShaderLang::SourceLocation& sourceLocation) const -> const Identifier*
{
	while (identifier->target.category == IdentifierCategory::Alias)
		identifier = &m_context->aliases.Retrieve(identifier->target.index, sourceLocation);

	return identifier;
}

// Processes the bodies of all pending functions, records which function calls which, propagates
// function flags to callers and finally rejects discarding entry functions outside the fragment stage.
void SanitizeVisitor::ResolveFunctions()
{
	// Once every function is known, we can evaluate function content
	for (auto& pendingFunc : m_context->pendingFunctions)
	{
		PushScope();

		for (auto& parameter : pendingFunc.cloneNode->parameters)
		{
			if (!m_context->options.allowPartialSanitization || parameter.type.IsResultingValue())
			{
				parameter.varIndex = RegisterVariable(parameter.name, parameter.type.GetResultingValue(), parameter.varIndex, parameter.sourceLocation);
				SanitizeIdentifier(parameter.name);
			}
			else
				RegisterUnresolved(parameter.name);
		}

		// Per-function state, exposed through m_context->currentFunction while the body is cloned
		CurrentFunctionData tempFuncData;
		if (pendingFunc.cloneNode->entryStage.HasValue())
			tempFuncData.stageType = pendingFunc.cloneNode->entryStage.GetResultingValue();

		m_context->currentFunction = &tempFuncData;

		// Statements are appended to the clone's own list; the previous list is restored afterwards
		std::vector* previousList = m_context->currentStatementList;
		m_context->currentStatementList = &pendingFunc.cloneNode->statements;

		pendingFunc.cloneNode->statements.reserve(pendingFunc.node->statements.size());
		for (auto& statement : pendingFunc.node->statements)
			pendingFunc.cloneNode->statements.push_back(CloneStatement(MandatoryStatement(statement, pendingFunc.cloneNode->sourceLocation)));
m_context->currentStatementList = previousList;
		m_context->currentFunction = nullptr;

		// Record, on every function this one calls, that it is called by this function
		std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex;
		for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i))
		{
			auto& targetFunc = m_context->functions.Retrieve(i, pendingFunc.cloneNode->sourceLocation);
			targetFunc.calledByFunctions.UnboundedSet(funcIndex);
		}

		PopScope();
	}
	m_context->pendingFunctions.clear();

	// Push each function's flags to its callers; seen is reset after every starting function
	Bitset<> seen;
	for (const auto& [funcIndex, funcData] : m_context->functions.values)
	{
		PropagateFunctionFlags(funcIndex, funcData.flags, seen);
		seen.Clear();
	}

	// An entry function carrying the DoesDiscard flag must belong to the fragment stage
	for (const auto& [funcIndex, funcData] : m_context->functions.values)
	{
		if (funcData.flags.Test(FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
			throw ShaderLang::CompilerDiscardOutsideOfFragmentStageError{ funcData.node->sourceLocation, funcData.node->entryStage.GetResultingValue() };
	}
}

// Resolves the struct index behind an alias by resolving the alias' target type.
std::size_t SanitizeVisitor::ResolveStruct(const AliasType& aliasType, const ShaderLang::SourceLocation& sourceLocation)
{
	return ResolveStruct(aliasType.targetType->type, sourceLocation);
}

// Resolves the struct index behind an expression type by dispatching on the held alternative.
// Alternatives of the first group are forwarded to the matching ResolveStruct overload; those of the
// second group throw CompilerStructExpectedError. Any unlisted alternative fails to compile.
std::size_t SanitizeVisitor::ResolveStruct(const ExpressionType& exprType, const ShaderLang::SourceLocation& sourceLocation)
{
	return std::visit([&](auto&& arg) -> std::size_t
	{
		using T = std::decay_t;

		if constexpr (std::is_same_v || std::is_same_v || std::is_same_v)
			return ResolveStruct(arg, sourceLocation);
		else if constexpr (std::is_same_v ||
		                   std::is_same_v ||
		                   std::is_same_v ||
		                   std::is_same_v ||
		                   std::is_same_v ||
		                   std::is_same_v ||
		                   std::is_same_v ||
		                   std::is_same_v ||
		                   std::is_same_v ||
		                   std::is_same_v)
		{
			throw ShaderLang::CompilerStructExpectedError{ sourceLocation, ToString(exprType, sourceLocation) };
		}
		else
			static_assert(AlwaysFalse::value, "non-exhaustive visitor");
	}, exprType);
}

// Returns the struct index stored in a struct type.
std::size_t SanitizeVisitor::ResolveStruct(const StructType&
structType, const ShaderLang::SourceLocation& /*sourceLocation*/) { return structType.structIndex; } std::size_t SanitizeVisitor::ResolveStruct(const UniformType& uniformType, const ShaderLang::SourceLocation& /*sourceLocation*/) { return uniformType.containedType.structIndex; } ExpressionType SanitizeVisitor::ResolveType(const ExpressionType& exprType, bool resolveAlias, const ShaderLang::SourceLocation& sourceLocation) { if (!IsTypeExpression(exprType)) { if (resolveAlias || m_context->options.removeAliases) return ResolveAlias(exprType); else return exprType; } std::size_t typeIndex = std::get(exprType).typeIndex; const auto& type = m_context->types.Retrieve(typeIndex, sourceLocation); if (!std::holds_alternative(type)) throw ShaderLang::CompilerFullTypeExpectedError{ sourceLocation, ToString(type, sourceLocation) }; return std::get(type); } std::optional SanitizeVisitor::ResolveTypeExpr(const ExpressionValue& exprTypeValue, bool resolveAlias, const ShaderLang::SourceLocation& sourceLocation) { if (!exprTypeValue.HasValue()) return NoType{}; if (exprTypeValue.IsResultingValue()) return ResolveType(exprTypeValue.GetResultingValue(), resolveAlias, sourceLocation); assert(exprTypeValue.IsExpression()); ExpressionPtr expression = CloneExpression(exprTypeValue.GetExpression()); const ExpressionType* exprType = GetExpressionType(*expression); if (!exprType) return std::nullopt; //if (!IsTypeType(exprType)) // throw AstError{ "type expected" }; return ResolveType(*exprType, resolveAlias, sourceLocation); } void SanitizeVisitor::SanitizeIdentifier(std::string& identifier) { // Append _ until the identifier is no longer found while (m_context->options.reservedIdentifiers.find(identifier) != m_context->options.reservedIdentifiers.end()) { do { identifier += "_"; } while (FindIdentifier(identifier) != nullptr); } } MultiStatementPtr SanitizeVisitor::SanitizeInternal(MultiStatement& rootNode, std::string* error) { MultiStatementPtr output; { // First pass, evaluate 
everything except function code try { output = StaticUniquePointerCast(AstCloner::Clone(rootNode)); } catch (const std::runtime_error& err) { if (!error) throw; *error = err.what(); } ResolveFunctions(); } return output; } std::string SanitizeVisitor::ToString(const ExpressionType& exprType, const ShaderLang::SourceLocation& sourceLocation) const { Stringifier stringifier; stringifier.aliasStringifier = [&](std::size_t aliasIndex) { return m_context->aliases.Retrieve(aliasIndex, sourceLocation).name; }; stringifier.structStringifier = [&](std::size_t structIndex) { return m_context->structs.Retrieve(structIndex, sourceLocation)->name; }; stringifier.typeStringifier = [&](std::size_t typeIndex) { return ToString(m_context->types.Retrieve(typeIndex, sourceLocation), sourceLocation); }; return ShaderAst::ToString(exprType, stringifier); } std::string SanitizeVisitor::ToString(const NamedPartialType& partialType, const ShaderLang::SourceLocation& /*sourceLocation*/) const { return partialType.name + " (partial)"; } template std::string SanitizeVisitor::ToString(const std::variant& value, const ShaderLang::SourceLocation& sourceLocation) const { return std::visit([&](auto&& arg) { return ToString(arg, sourceLocation); }, value); } void SanitizeVisitor::TypeMustMatch(const ExpressionType& left, const ExpressionType& right, const ShaderLang::SourceLocation& sourceLocation) const { if (ResolveAlias(left) != ResolveAlias(right)) throw ShaderLang::CompilerUnmatchingTypesError{ sourceLocation, ToString(left, sourceLocation), ToString(right, sourceLocation) }; } auto SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right, const ShaderLang::SourceLocation& sourceLocation) -> ValidationResult { const ExpressionType* leftType = GetExpressionType(*left); const ExpressionType* rightType = GetExpressionType(*right); if (!leftType || !rightType) return ValidationResult::Unresolved; TypeMustMatch(*leftType, *rightType, sourceLocation); return 
ValidationResult::Validated;
}

// Validates an alias declaration and registers the alias.
// Throws AstEmptyIdentifierError on an empty name and CompilerAliasUnexpectedTypeError when the
// aliased expression is neither a struct, a function nor another alias.
// Returns Unresolved while the type of the aliased expression is unknown.
auto SanitizeVisitor::Validate(DeclareAliasStatement& node) -> ValidationResult
{
	if (node.name.empty())
		throw ShaderLang::AstEmptyIdentifierError{ node.sourceLocation };

	const ExpressionType* exprType = GetExpressionType(*node.expression);
	if (!exprType)
		return ValidationResult::Unresolved;

	const ExpressionType& resolvedType = ResolveAlias(*exprType);

	Identifier aliasIdentifier;
	aliasIdentifier.name = node.name;

	if (IsStructType(resolvedType))
	{
		std::size_t structIndex = ResolveStruct(resolvedType, node.expression->sourceLocation);
		aliasIdentifier.target = { structIndex, IdentifierCategory::Struct };
	}
	else if (IsFunctionType(resolvedType))
	{
		std::size_t funcIndex = std::get(resolvedType).funcIndex;
		aliasIdentifier.target = { funcIndex, IdentifierCategory::Function };
	}
	else if (IsAliasType(resolvedType))
	{
		const AliasType& alias = std::get(resolvedType);
		aliasIdentifier.target = { alias.aliasIndex, IdentifierCategory::Alias };
	}
	else
		throw ShaderLang::CompilerAliasUnexpectedTypeError{ node.sourceLocation, ToString(*exprType, node.expression->sourceLocation) };

	node.aliasIndex = RegisterAlias(node.name, std::move(aliasIdentifier), node.aliasIndex, node.sourceLocation);
	return ValidationResult::Validated;
}

// Validates a while statement: condition and body are mandatory, and the condition must be a
// boolean once aliases are resolved (CompilerConditionExpectedBoolError otherwise).
auto SanitizeVisitor::Validate(WhileStatement& node) -> ValidationResult
{
	const ExpressionType* conditionType = GetExpressionType(MandatoryExpr(node.condition, node.sourceLocation));
	MandatoryStatement(node.body, node.sourceLocation);

	if (!conditionType)
		return ValidationResult::Unresolved;

	if (ResolveAlias(*conditionType) != ExpressionType{ PrimitiveType::Boolean })
		throw ShaderLang::CompilerConditionExpectedBoolError{ node.condition->sourceLocation, ToString(*conditionType, node.condition->sourceLocation) };

	return ValidationResult::Validated;
}

// Validates an index access and caches the resulting expression type on the node.
// Indexing a type expression instantiates a partial type from its parameters; indexing a value
// accepts exactly one integer index on an array, struct, matrix or vector.
auto SanitizeVisitor::Validate(AccessIndexExpression& node) -> ValidationResult
{
	const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expr, node.sourceLocation));
	if
(!exprType)
		return ValidationResult::Unresolved;

	ExpressionType resolvedExprType = ResolveAlias(*exprType);

	if (IsTypeExpression(resolvedExprType))
	{
		// Partial type instantiation: the indices are the type parameters
		std::size_t typeIndex = std::get(resolvedExprType).typeIndex;
		const auto& type = m_context->types.Retrieve(typeIndex, node.sourceLocation);

		if (!std::holds_alternative(type))
			throw ShaderLang::CompilerExpectedPartialTypeError{ node.sourceLocation, ToString(std::get(type), node.sourceLocation) };

		const auto& partialType = std::get(type);
		if (partialType.type.parameters.size() != node.indices.size())
			throw ShaderLang::CompilerPartialTypeParameterCountMismatchError{ node.sourceLocation, SafeCast(partialType.type.parameters.size()), SafeCast(node.indices.size()) };

		StackVector parameters = NazaraStackVector(TypeParameter, partialType.type.parameters.size());
		for (std::size_t i = 0; i < partialType.type.parameters.size(); ++i)
		{
			const ExpressionPtr& indexExpr = node.indices[i];
			switch (partialType.type.parameters[i])
			{
				case TypeParameterCategory::ConstantValue:
				{
					std::optional value = ComputeConstantValue(*indexExpr);
					if (!value.has_value())
						return ValidationResult::Unresolved;

					parameters.push_back(std::move(*value));
					break;
				}

				case TypeParameterCategory::FullType:
				case TypeParameterCategory::PrimitiveType:
				case TypeParameterCategory::StructType:
				{
					const ExpressionType* indexExprType = GetExpressionType(*indexExpr);
					if (!indexExprType)
						return ValidationResult::Unresolved;

					ExpressionType resolvedType = ResolveType(*indexExprType, true, node.sourceLocation);

					// Extra constraint on the resolved type for the primitive and struct categories
					switch (partialType.type.parameters[i])
					{
						case TypeParameterCategory::PrimitiveType:
						{
							if (!IsPrimitiveType(resolvedType))
								throw ShaderLang::CompilerPartialTypeExpectError{ indexExpr->sourceLocation, "primitive", SafeCast(i) };

							break;
						}

						case TypeParameterCategory::StructType:
						{
							if (!IsStructType(resolvedType))
								throw ShaderLang::CompilerPartialTypeExpectError{ indexExpr->sourceLocation, "struct", SafeCast(i) };

							break;
						}

						default:
							break;
					}

					parameters.push_back(resolvedType);
					break;
				}
			}
		}

		assert(parameters.size() == partialType.type.parameters.size());
		node.cachedExpressionType = partialType.type.buildFunc(parameters.data(), parameters.size(), node.sourceLocation);
	}
	else
	{
		// Value indexing: exactly one index is accepted
		if (node.indices.size() != 1)
			throw ShaderLang::AstNoIndexError{ node.sourceLocation };

		for (const auto& indexExpr : node.indices)
		{
			const ExpressionType* indexType = GetExpressionType(*indexExpr);
			if (!indexType)
				return ValidationResult::Unresolved;

			if (!IsPrimitiveType(*indexType))
				throw ShaderLang::CompilerIndexRequiresIntegerIndicesError{ node.sourceLocation, ToString(*indexType, indexExpr->sourceLocation) };

			PrimitiveType primitiveIndexType = std::get(*indexType);
			if (primitiveIndexType != PrimitiveType::Int32 && primitiveIndexType != PrimitiveType::UInt32)
				throw ShaderLang::CompilerIndexRequiresIntegerIndicesError{ node.sourceLocation, ToString(*indexType, indexExpr->sourceLocation) };

			if (IsArrayType(resolvedExprType))
			{
				const ArrayType& arrayType = std::get(resolvedExprType);
				ExpressionType containedType = arrayType.containedType->type; //< Don't overwrite exprType directly since it contains arrayType
				resolvedExprType = std::move(containedType);
			}
			else if (IsStructType(resolvedExprType))
			{
				if (primitiveIndexType != PrimitiveType::Int32)
					throw ShaderLang::CompilerIndexStructRequiresInt32IndicesError{ node.sourceLocation, ToString(*indexType, indexExpr->sourceLocation) };

				// NOTE(review): the index expression is cast without checking its node type, and
				// members[index] is not range-checked here - confirm both are guaranteed upstream
				ConstantValueExpression& constantExpr = static_cast(*indexExpr);

				Int32 index = std::get(constantExpr.value);

				std::size_t structIndex = ResolveStruct(resolvedExprType, indexExpr->sourceLocation);
				const StructDescription* s = m_context->structs.Retrieve(structIndex, indexExpr->sourceLocation);

				std::optional resolvedExprTypeOpt = ResolveTypeExpr(s->members[index].type, true, indexExpr->sourceLocation);
				if (!resolvedExprTypeOpt.has_value())
					return ValidationResult::Unresolved;

				resolvedExprType = std::move(resolvedExprTypeOpt).value();
			}
			else if (IsMatrixType(resolvedExprType))
			{
				// Matrix index (ex: mat[2])
				MatrixType matrixType = std::get(resolvedExprType);

				//TODO: Handle row-major matrices
				resolvedExprType = VectorType{ matrixType.rowCount, matrixType.type };
			}
			else if (IsVectorType(resolvedExprType))
			{
				// Swizzle expression with one component (ex: vec[2])
				VectorType swizzledVec = std::get(resolvedExprType);

				resolvedExprType = swizzledVec.type;
			}
			else
				throw ShaderLang::CompilerIndexUnexpectedTypeError{ node.sourceLocation, ToString(*indexType, indexExpr->sourceLocation) };
		}

		node.cachedExpressionType = std::move(resolvedExprType);
	}

	return ValidationResult::Validated;
}

// Validates an assignment: the left side must be an l-value and the types must match, either
// directly (simple assignment) or through the binary operation implied by a compound assignment.
// The node's cached type is the type of the left expression.
auto SanitizeVisitor::Validate(AssignExpression& node) -> ValidationResult
{
	const ExpressionType* leftExprType = GetExpressionType(MandatoryExpr(node.left, node.sourceLocation));
	if (!leftExprType)
		return ValidationResult::Unresolved;

	const ExpressionType* rightExprType = GetExpressionType(MandatoryExpr(node.right, node.sourceLocation));
	if (!rightExprType)
		return ValidationResult::Unresolved;

	if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
		throw ShaderLang::CompilerAssignTemporaryError{ node.sourceLocation };

	// Set for compound assignments only
	std::optional binaryType;
	switch (node.op)
	{
		case AssignType::Simple:
			if (TypeMustMatch(node.left, node.right, node.sourceLocation) == ValidationResult::Unresolved)
				return ValidationResult::Unresolved;

			break;

		case AssignType::CompoundAdd: binaryType = BinaryType::Add; break;
		case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break;
		case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break;
		case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break;
		case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break;
		case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break;
	}

	if (binaryType)
	{
		ExpressionType expressionType = ValidateBinaryOp(*binaryType, ResolveAlias(*leftExprType), ResolveAlias(*rightExprType), node.sourceLocation);
		TypeMustMatch(*leftExprType, expressionType,
node.sourceLocation); if (m_context->options.removeCompoundAssignments) { node.op = AssignType::Simple; node.right = ShaderBuilder::Binary(*binaryType, AstCloner::Clone(*node.left), std::move(node.right)); node.right->cachedExpressionType = std::move(expressionType); } } node.cachedExpressionType = *leftExprType; return ValidationResult::Validated; } auto SanitizeVisitor::Validate(BinaryExpression& node) -> ValidationResult { const ExpressionType* leftExprType = GetExpressionType(MandatoryExpr(node.left, node.sourceLocation)); if (!leftExprType) return ValidationResult::Unresolved; const ExpressionType* rightExprType = GetExpressionType(MandatoryExpr(node.right, node.sourceLocation)); if (!rightExprType) return ValidationResult::Unresolved; node.cachedExpressionType = ValidateBinaryOp(node.op, ResolveAlias(*leftExprType), ResolveAlias(*rightExprType), node.sourceLocation); return ValidationResult::Validated; } auto SanitizeVisitor::Validate(CallFunctionExpression& node) -> ValidationResult { std::size_t targetFuncIndex; if (node.targetFunction->GetType() == NodeType::FunctionExpression) targetFuncIndex = static_cast(*node.targetFunction).funcId; else if (node.targetFunction->GetType() == NodeType::AliasValueExpression) { const auto& alias = static_cast(*node.targetFunction); const Identifier* aliasIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId, node.sourceLocation), node.sourceLocation); if (aliasIdentifier->target.category != IdentifierCategory::Function) throw ShaderLang::CompilerFunctionCallExpectedFunctionError{ node.sourceLocation }; targetFuncIndex = aliasIdentifier->target.index; } else throw ShaderLang::CompilerFunctionCallExpectedFunctionError{ node.sourceLocation }; auto& funcData = m_context->functions.Retrieve(targetFuncIndex, node.sourceLocation); const DeclareFunctionStatement* referenceDeclaration = funcData.node; if (referenceDeclaration->entryStage.HasValue()) throw 
ShaderLang::CompilerFunctionCallUnexpectedEntryFunctionError{ node.sourceLocation, referenceDeclaration->name }; for (std::size_t i = 0; i < node.parameters.size(); ++i) { const ExpressionType* parameterType = GetExpressionType(*node.parameters[i]); if (!parameterType) return ValidationResult::Unresolved; if (ResolveAlias(*parameterType) != ResolveAlias(referenceDeclaration->parameters[i].type.GetResultingValue())) throw ShaderLang::CompilerFunctionCallUnmatchingParameterTypeError{ node.sourceLocation, referenceDeclaration->name, SafeCast(i), ToString(referenceDeclaration->parameters[i].type.GetResultingValue(), referenceDeclaration->parameters[i].sourceLocation), ToString(*parameterType, node.parameters[i]->sourceLocation) }; } if (node.parameters.size() != referenceDeclaration->parameters.size()) throw ShaderLang::CompilerFunctionCallUnmatchingParameterCountError{ node.sourceLocation, referenceDeclaration->name, SafeCast(referenceDeclaration->parameters.size()), SafeCast(node.parameters.size()) }; node.cachedExpressionType = referenceDeclaration->returnType.GetResultingValue(); return ValidationResult::Validated; } auto SanitizeVisitor::Validate(CastExpression& node) -> ValidationResult { std::optional targetTypeOpt = ResolveTypeExpr(node.targetType, false, node.sourceLocation); if (!targetTypeOpt) return ValidationResult::Unresolved; const ExpressionType& targetType = ResolveAlias(*targetTypeOpt); auto& firstExprPtr = MandatoryExpr(node.expressions.front(), node.sourceLocation); std::size_t expressionCount = 0; for (; expressionCount < node.expressions.size(); ++expressionCount) { if (!node.expressions[expressionCount]) break; } if (IsMatrixType(targetType)) { const MatrixType& targetMatrixType = std::get(targetType); const ExpressionType* firstExprType = GetExpressionType(firstExprPtr); if (!firstExprType) return ValidationResult::Unresolved; if (IsMatrixType(ResolveAlias(*firstExprType))) { if (expressionCount != 1) throw 
ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast(expressionCount), 1 }; // Matrix to matrix cast: always valid } else { // Matrix builder (from vectors) assert(targetMatrixType.columnCount <= 4); if (expressionCount != targetMatrixType.columnCount) throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast(expressionCount), SafeCast(targetMatrixType.columnCount) }; for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i) { const auto& exprPtr = node.expressions[i]; assert(exprPtr); const ExpressionType* exprType = GetExpressionType(*exprPtr); if (!exprType) return ValidationResult::Unresolved; const ExpressionType& resolvedExprType = ResolveAlias(*exprType); if (!IsVectorType(resolvedExprType)) throw ShaderLang::CompilerCastMatrixExpectedVectorError{ node.sourceLocation, ToString(resolvedExprType, node.expressions[i]->sourceLocation) }; const VectorType& vecType = std::get(resolvedExprType); if (vecType.componentCount != targetMatrixType.rowCount) throw ShaderLang::CompilerCastMatrixVectorComponentMismatchError{ node.expressions[i]->sourceLocation, SafeCast(vecType.componentCount), SafeCast(targetMatrixType.rowCount) }; if (vecType.type != targetMatrixType.type) throw ShaderLang::CompilerCastIncompatibleBaseTypesError{ node.expressions[i]->sourceLocation, ToString(targetMatrixType.type, node.sourceLocation), ToString(vecType.type, node.sourceLocation) }; } } } else if (IsPrimitiveType(targetType)) { // Cast between primitive types if (expressionCount != 1) throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast(expressionCount), 1 }; const ExpressionType* fromType = GetExpressionType(*node.expressions[0]); if (!fromType) return ValidationResult::Unresolved; const ExpressionType& resolvedFromType = ResolveAlias(*fromType); if (!IsPrimitiveType(resolvedFromType)) throw ShaderLang::CompilerCastIncompatibleTypesError{ node.expressions[0]->sourceLocation, ToString(targetType, 
node.sourceLocation), ToString(resolvedFromType, node.sourceLocation) }; PrimitiveType fromPrimitiveType = std::get(resolvedFromType); PrimitiveType targetPrimitiveType = std::get(targetType); bool areTypeCompatibles = [&] { switch (targetPrimitiveType) { case PrimitiveType::Boolean: case PrimitiveType::String: return false; case PrimitiveType::Float32: { switch (fromPrimitiveType) { case PrimitiveType::Boolean: case PrimitiveType::String: return false; case PrimitiveType::Float32: case PrimitiveType::Int32: case PrimitiveType::UInt32: return true; } } case PrimitiveType::Int32: { switch (fromPrimitiveType) { case PrimitiveType::Boolean: case PrimitiveType::String: case PrimitiveType::UInt32: return false; case PrimitiveType::Float32: case PrimitiveType::Int32: return true; } } case PrimitiveType::UInt32: { switch (fromPrimitiveType) { case PrimitiveType::Boolean: case PrimitiveType::String: return false; case PrimitiveType::Float32: case PrimitiveType::Int32: case PrimitiveType::UInt32: return true; } } } throw ShaderLang::AstInternalError{ node.sourceLocation, "unexpected cast from " + ShaderAst::ToString(fromPrimitiveType) + " to " + ShaderAst::ToString(targetPrimitiveType) }; }(); if (!areTypeCompatibles) throw ShaderLang::CompilerCastIncompatibleTypesError{ node.expressions[0]->sourceLocation, ToString(targetType, node.sourceLocation), ToString(resolvedFromType, node.sourceLocation) }; } else if (IsVectorType(targetType)) { PrimitiveType targetBaseType = std::get(targetType).type; auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t { if (IsVectorType(exprType)) return std::get(exprType).componentCount; else { assert(IsPrimitiveType(exprType)); return 1; } }; std::size_t componentCount = 0; std::size_t requiredComponents = GetComponentCount(targetType); for (auto& exprPtr : node.expressions) { if (!exprPtr) break; const ExpressionType* exprType = GetExpressionType(*exprPtr); if (!exprType) return ValidationResult::Unresolved; const 
ExpressionType& resolvedExprType = ResolveAlias(*exprType); if (IsPrimitiveType(resolvedExprType)) { PrimitiveType primitiveType = std::get(resolvedExprType); if (primitiveType != targetBaseType) throw ShaderLang::CompilerCastIncompatibleBaseTypesError{ exprPtr->sourceLocation, ToString(targetBaseType, node.sourceLocation), ToString(primitiveType, exprPtr->sourceLocation) }; } else if (IsVectorType(resolvedExprType)) { PrimitiveType primitiveType = std::get(resolvedExprType).type; if (primitiveType != targetBaseType) throw ShaderLang::CompilerCastIncompatibleBaseTypesError{ exprPtr->sourceLocation, ToString(targetBaseType, node.sourceLocation), ToString(primitiveType, exprPtr->sourceLocation) }; } else throw ShaderLang::CompilerCastIncompatibleTypesError{ exprPtr->sourceLocation, ToString(targetType, node.sourceLocation), ToString(resolvedExprType, exprPtr->sourceLocation) }; componentCount += GetComponentCount(resolvedExprType); } if (componentCount != requiredComponents) throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast(componentCount), SafeCast(requiredComponents) }; } else throw ShaderLang::CompilerInvalidCastError{ node.sourceLocation, ToString(targetType, node.sourceLocation) }; node.cachedExpressionType = targetType; node.targetType = targetType; return ValidationResult::Validated; } auto SanitizeVisitor::Validate(DeclareVariableStatement& node) -> ValidationResult { ExpressionType resolvedType; if (!node.varType.HasValue()) { if (!node.initialExpression) throw ShaderLang::CompilerVarDeclarationMissingTypeAndValueError{ node.sourceLocation }; const ExpressionType* initialExprType = GetExpressionType(*node.initialExpression); if (!initialExprType) { RegisterUnresolved(node.varName); return ValidationResult::Unresolved; } resolvedType = *initialExprType; } else { std::optional varType = ResolveTypeExpr(node.varType, false, node.sourceLocation); if (!varType) { RegisterUnresolved(node.varName); return 
ValidationResult::Unresolved; } resolvedType = std::move(varType).value(); if (node.initialExpression) { const ExpressionType* initialExprType = GetExpressionType(*node.initialExpression); if (!initialExprType) { RegisterUnresolved(node.varName); return ValidationResult::Unresolved; } TypeMustMatch(resolvedType, *initialExprType, node.sourceLocation); } } node.varIndex = RegisterVariable(node.varName, resolvedType, node.varIndex, node.sourceLocation); node.varType = std::move(resolvedType); if (m_context->options.makeVariableNameUnique) { // Since we are registered, FindIdentifier will find us auto IgnoreOurself = [varIndex = *node.varIndex](const IdentifierData& identifierData) { if (identifierData.category == IdentifierCategory::Variable && identifierData.index == varIndex) return false; return true; }; if (FindIdentifier(node.varName, IgnoreOurself) != nullptr) { // Try to make variable name unique by appending _X to its name (incrementing X until it's unique) to the variable name until by incrementing X unsigned int cloneIndex = 2; std::string candidateName; do { candidateName = node.varName + "_" + std::to_string(cloneIndex++); } while (FindIdentifier(candidateName, IgnoreOurself) != nullptr); node.varName = std::move(candidateName); } } SanitizeIdentifier(node.varName); return ValidationResult::Validated; } auto SanitizeVisitor::Validate(IntrinsicExpression& node) -> ValidationResult { auto IsFloatingPointVector = [](const ExpressionType& type) { return type == ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }; }; auto CheckNotBoolean = [](Expression& expression, const ExpressionType& type) { if ((IsPrimitiveType(type) && std::get(type) == PrimitiveType::Boolean) || (IsVectorType(type) && std::get(type).type == PrimitiveType::Boolean)) throw ShaderLang::CompilerIntrinsicUnexpectedBooleanError{ expression.sourceLocation }; }; auto CheckFloatingPoint = [](Expression& expression, const ExpressionType& type) { if ((IsPrimitiveType(type) && std::get(type) 
!= PrimitiveType::Float32) || (IsVectorType(type) && std::get(type).type != PrimitiveType::Float32)) throw ShaderLang::CompilerIntrinsicExpectedFloatError{ expression.sourceLocation }; }; auto SetReturnTypeToFirstParameterType = [&] { node.cachedExpressionType = GetExpressionTypeSecure(*node.parameters.front()); return ValidationResult::Validated; }; auto SetReturnTypeToFirstParameterInnerType = [&] { node.cachedExpressionType = std::get(GetExpressionTypeSecure(*node.parameters.front())).type; return ValidationResult::Validated; }; auto IsUnresolved = [](ValidationResult result) { return result == ValidationResult::Unresolved; }; // Parameter validation and return type attribution switch (node.intrinsic) { case IntrinsicType::CrossProduct: if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector, "floating-point vector"))) return ValidationResult::Unresolved; return SetReturnTypeToFirstParameterType(); case IntrinsicType::DotProduct: if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector, "floating-point vector"))) return ValidationResult::Unresolved; return SetReturnTypeToFirstParameterInnerType(); case IntrinsicType::Exp: if (IsUnresolved(ValidateIntrinsicParamCount<1>(node)) || IsUnresolved(ValidateIntrinsicParameter<0>(node, CheckFloatingPoint))) return ValidationResult::Unresolved; return SetReturnTypeToFirstParameterType(); case IntrinsicType::Length: if (IsUnresolved(ValidateIntrinsicParamCount<1>(node)) || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector, "floating-point vector"))) return ValidationResult::Unresolved; return SetReturnTypeToFirstParameterInnerType(); case IntrinsicType::Max: case IntrinsicType::Min: if 
(IsUnresolved(ValidateIntrinsicParamCount<2>(node)) || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) || IsUnresolved(ValidateIntrinsicParameter<0>(node, CheckNotBoolean))) return ValidationResult::Unresolved; return SetReturnTypeToFirstParameterType(); case IntrinsicType::Normalize: if (IsUnresolved(ValidateIntrinsicParamCount<1>(node)) || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector, "floating-point vector"))) return ValidationResult::Unresolved; return SetReturnTypeToFirstParameterType(); case IntrinsicType::Pow: if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) || IsUnresolved(ValidateIntrinsicParameter<0>(node, CheckFloatingPoint))) return ValidationResult::Unresolved; return SetReturnTypeToFirstParameterType(); case IntrinsicType::Reflect: if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) || IsUnresolved(ValidateIntrinsicParamMatchingType(node)) || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsFloatingPointVector, "floating-point vector"))) return ValidationResult::Unresolved; return SetReturnTypeToFirstParameterType(); case IntrinsicType::SampleTexture: { if (IsUnresolved(ValidateIntrinsicParamCount<2>(node)) || IsUnresolved(ValidateIntrinsicParameterType<0>(node, IsSamplerType, "sampler type"))) return ValidationResult::Unresolved; // Special check: vector dimensions must match sample type const SamplerType& samplerType = std::get(ResolveAlias(GetExpressionTypeSecure(*node.parameters[0]))); std::size_t requiredComponentCount = 0; switch (samplerType.dim) { case ImageType::E1D: requiredComponentCount = 1; break; case ImageType::E1D_Array: case ImageType::E2D: requiredComponentCount = 2; break; case ImageType::E2D_Array: case ImageType::E3D: case ImageType::Cubemap: requiredComponentCount = 3; break; } if (requiredComponentCount == 0) throw ShaderLang::AstInternalError{ node.parameters[0]->sourceLocation, "unhandled sampler dimensions" }; auto 
IsRightType = [=](const ExpressionType& type) { return type == ExpressionType{ VectorType{ requiredComponentCount, PrimitiveType::Float32 } }; }; if (IsUnresolved(ValidateIntrinsicParameterType<1>(node, IsRightType, "sampler of requirement components"))) return ValidationResult::Unresolved; node.cachedExpressionType = VectorType{ 4, samplerType.sampledType }; return ValidationResult::Validated; } } throw ShaderLang::AstInternalError{ node.sourceLocation, "unhandled intrinsic" }; } auto SanitizeVisitor::Validate(SwizzleExpression& node) -> ValidationResult { const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expression, node.sourceLocation)); if (!exprType) return ValidationResult::Unresolved; const ExpressionType& resolvedExprType = ResolveAlias(*exprType); if (!IsPrimitiveType(resolvedExprType) && !IsVectorType(resolvedExprType)) throw ShaderLang::CompilerSwizzleUnexpectedTypeError{ node.sourceLocation, ToString(*exprType, node.expression->sourceLocation) }; PrimitiveType baseType; std::size_t componentCount; if (IsPrimitiveType(resolvedExprType)) { if (m_context->options.removeScalarSwizzling) throw ShaderLang::AstInternalError{ node.sourceLocation, "scalar swizzling should have been removed before validating" }; baseType = std::get(resolvedExprType); componentCount = 1; } else { const VectorType& vecType = std::get(resolvedExprType); baseType = vecType.type; componentCount = vecType.componentCount; } if (node.componentCount > 4) throw ShaderLang::CompilerInvalidSwizzleError{ node.sourceLocation }; for (std::size_t i = 0; i < node.componentCount; ++i) { if (node.components[i] >= componentCount) throw ShaderLang::CompilerInvalidSwizzleError{ node.sourceLocation }; } if (node.componentCount > 1) { node.cachedExpressionType = VectorType{ node.componentCount, baseType }; } else node.cachedExpressionType = baseType; return ValidationResult::Validated; } auto SanitizeVisitor::Validate(UnaryExpression& node) -> ValidationResult { const ExpressionType* 
exprType = GetExpressionType(MandatoryExpr(node.expression, node.sourceLocation)); if (!exprType) return ValidationResult::Unresolved; const ExpressionType& resolvedExprType = ResolveAlias(*exprType); switch (node.op) { case UnaryType::LogicalNot: { if (resolvedExprType != ExpressionType(PrimitiveType::Boolean)) throw ShaderLang::CompilerUnaryUnsupportedError{ node.sourceLocation, ToString(*exprType, node.sourceLocation) }; break; } case UnaryType::Minus: case UnaryType::Plus: { PrimitiveType basicType; if (IsPrimitiveType(resolvedExprType)) basicType = std::get(resolvedExprType); else if (IsVectorType(resolvedExprType)) basicType = std::get(resolvedExprType).type; else throw ShaderLang::CompilerUnaryUnsupportedError{ node.sourceLocation, ToString(*exprType, node.sourceLocation) }; if (basicType != PrimitiveType::Float32 && basicType != PrimitiveType::Int32 && basicType != PrimitiveType::UInt32) throw ShaderLang::CompilerUnaryUnsupportedError{ node.sourceLocation, ToString(*exprType, node.sourceLocation) }; break; } } node.cachedExpressionType = *exprType; return ValidationResult::Validated; } auto SanitizeVisitor::Validate(VariableValueExpression& node) -> ValidationResult { node.cachedExpressionType = m_context->variableTypes.Retrieve(node.variableId, node.sourceLocation); return ValidationResult::Validated; } ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType, const ShaderLang::SourceLocation& sourceLocation) { if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType)) throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType)) throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "right", ToString(rightExprType, sourceLocation) }; if (IsPrimitiveType(leftExprType)) { 
PrimitiveType leftType = std::get(leftExprType); switch (op) { case BinaryType::CompGe: case BinaryType::CompGt: case BinaryType::CompLe: case BinaryType::CompLt: if (leftType == PrimitiveType::Boolean) throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; [[fallthrough]]; case BinaryType::CompEq: case BinaryType::CompNe: { TypeMustMatch(leftExprType, rightExprType, sourceLocation); return PrimitiveType::Boolean; } case BinaryType::Add: case BinaryType::Subtract: TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; case BinaryType::Multiply: case BinaryType::Divide: { switch (leftType) { case PrimitiveType::Float32: case PrimitiveType::Int32: case PrimitiveType::UInt32: { if (IsMatrixType(rightExprType)) { TypeMustMatch(leftType, std::get(rightExprType).type, sourceLocation); return rightExprType; } else if (IsPrimitiveType(rightExprType)) { TypeMustMatch(leftType, rightExprType, sourceLocation); return leftExprType; } else if (IsVectorType(rightExprType)) { TypeMustMatch(leftType, std::get(rightExprType).type, sourceLocation); return rightExprType; } else throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation, ToString(leftExprType, sourceLocation), ToString(rightExprType, sourceLocation) }; break; } case PrimitiveType::Boolean: throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; default: throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation, ToString(leftExprType, sourceLocation), ToString(rightExprType, sourceLocation) }; } } case BinaryType::LogicalAnd: case BinaryType::LogicalOr: { if (leftType != PrimitiveType::Boolean) throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; TypeMustMatch(leftExprType, rightExprType, sourceLocation); return PrimitiveType::Boolean; } } } else if (IsMatrixType(leftExprType)) { const 
MatrixType& leftType = std::get(leftExprType); switch (op) { case BinaryType::CompGe: case BinaryType::CompGt: case BinaryType::CompLe: case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: TypeMustMatch(leftExprType, rightExprType, sourceLocation); return PrimitiveType::Boolean; case BinaryType::Add: case BinaryType::Subtract: TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; case BinaryType::Multiply: case BinaryType::Divide: { if (IsMatrixType(rightExprType)) { TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; //< FIXME } else if (IsPrimitiveType(rightExprType)) { TypeMustMatch(leftType.type, rightExprType, sourceLocation); return leftExprType; } else if (IsVectorType(rightExprType)) { const VectorType& rightType = std::get(rightExprType); TypeMustMatch(leftType.type, rightType.type, sourceLocation); if (leftType.columnCount != rightType.componentCount) throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation, ToString(leftExprType, sourceLocation), ToString(rightExprType, sourceLocation) }; return rightExprType; } else throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation, ToString(leftExprType, sourceLocation), ToString(rightExprType, sourceLocation) }; } case BinaryType::LogicalAnd: case BinaryType::LogicalOr: throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; } } else if (IsVectorType(leftExprType)) { const VectorType& leftType = std::get(leftExprType); switch (op) { case BinaryType::CompGe: case BinaryType::CompGt: case BinaryType::CompLe: case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: TypeMustMatch(leftExprType, rightExprType, sourceLocation); return PrimitiveType::Boolean; case BinaryType::Add: case BinaryType::Subtract: TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; case BinaryType::Multiply: case BinaryType::Divide: 
{ if (IsPrimitiveType(rightExprType)) { TypeMustMatch(leftType.type, rightExprType, sourceLocation); return leftExprType; } else if (IsVectorType(rightExprType)) { TypeMustMatch(leftType, rightExprType, sourceLocation); return rightExprType; } else throw ShaderLang::CompilerBinaryIncompatibleTypesError{ sourceLocation, ToString(leftExprType, sourceLocation), ToString(rightExprType, sourceLocation) }; break; } case BinaryType::LogicalAnd: case BinaryType::LogicalOr: throw ShaderLang::CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; } } throw ShaderLang::AstInternalError{ sourceLocation, "unchecked operation" }; } template auto SanitizeVisitor::ValidateIntrinsicParamCount(IntrinsicExpression& node) -> ValidationResult { if (node.parameters.size() != N) throw ShaderLang::CompilerIntrinsicExpectedParameterCountError{ node.sourceLocation, SafeCast(N) }; for (auto& param : node.parameters) MandatoryExpr(param, node.sourceLocation); return ValidationResult::Validated; } auto SanitizeVisitor::ValidateIntrinsicParamMatchingType(IntrinsicExpression& node) -> ValidationResult { const ExpressionType* firstParameterType = GetExpressionType(*node.parameters.front()); if (!firstParameterType) return ValidationResult::Unresolved; for (std::size_t i = 1; i < node.parameters.size(); ++i) { const ExpressionType* parameterType = GetExpressionType(*node.parameters[i]); if (!parameterType) return ValidationResult::Unresolved; if (ResolveAlias(*firstParameterType) != ResolveAlias(*parameterType)) throw ShaderLang::CompilerIntrinsicUnmatchingParameterTypeError{ node.parameters[i]->sourceLocation }; } return ValidationResult::Validated; } template auto SanitizeVisitor::ValidateIntrinsicParameter(IntrinsicExpression& node, F&& func) -> ValidationResult { assert(node.parameters.size() > N); auto& parameter = MandatoryExpr(node.parameters[N], node.sourceLocation); const ExpressionType* type = GetExpressionType(parameter); if (!type) return 
ValidationResult::Unresolved; const ExpressionType& resolvedType = ResolveAlias(*type); func(parameter, resolvedType); return ValidationResult::Validated; } template auto SanitizeVisitor::ValidateIntrinsicParameterType(IntrinsicExpression& node, F&& func, const char* typeStr) -> ValidationResult { assert(node.parameters.size() > N); auto& parameter = MandatoryExpr(node.parameters[N], node.sourceLocation); const ExpressionType* type = GetExpressionType(parameter); if (!type) return ValidationResult::Unresolved; const ExpressionType& resolvedType = ResolveAlias(*type); if (!func(resolvedType)) throw ShaderLang::CompilerIntrinsicExpectedTypeError{ parameter.sourceLocation, SafeCast(N), typeStr, ToString(*type, parameter.sourceLocation)}; return ValidationResult::Validated; } Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node, const ShaderLang::SourceLocation& sourceLocation) { if (!node) throw ShaderLang::AstMissingExpressionError{ sourceLocation }; return *node; } Statement& SanitizeVisitor::MandatoryStatement(const StatementPtr& node, const ShaderLang::SourceLocation& sourceLocation) { if (!node) throw ShaderLang::AstMissingStatementError{ sourceLocation }; return *node; } StatementPtr SanitizeVisitor::Unscope(StatementPtr node) { assert(node); if (node->GetType() == NodeType::ScopedStatement) return std::move(static_cast(*node).statement); else return node; } UInt32 SanitizeVisitor::ToSwizzleIndex(char c, const ShaderLang::SourceLocation& sourceLocation) { switch (c) { case 'r': case 'x': case 's': return 0u; case 'g': case 'y': case 't': return 1u; case 'b': case 'z': case 'p': return 2u; case 'a': case 'w': case 'q': return 3u; default: throw ShaderLang::CompilerInvalidSwizzleError{ sourceLocation, std::string(&c, 1) }; } } }