// Copyright (C) 2020 Jérôme Leclercq // This file is part of the "Nazara Engine - Shader generator" // For conditions of distribution and use, see copyright notice in Config.hpp #include #include #include #include #include #include #include #include namespace Nz::ShaderAst { namespace { struct AstError { std::string errMsg; }; template std::unique_ptr static_unique_pointer_cast(std::unique_ptr&& ptr) { return std::unique_ptr(static_cast(ptr.release())); } } struct SanitizeVisitor::Context { Options options; std::array entryFunctions = {}; std::unordered_set declaredExternalVar; std::unordered_set usedBindingIndexes; }; StatementPtr SanitizeVisitor::Sanitize(const StatementPtr& nodePtr, const Options& options, std::string* error) { StatementPtr clone; Context currentContext; currentContext.options = options; m_context = ¤tContext; CallOnExit resetContext([&] { m_context = nullptr; }); PushScope(); //< Global scope { try { clone = AstCloner::Clone(nodePtr); } catch (const AstError& err) { if (!error) throw std::runtime_error(err.errMsg); *error = err.errMsg; } } PopScope(); return clone; } const ExpressionType& SanitizeVisitor::CheckField(const ExpressionType& structType, const std::string* memberIdentifier, std::size_t remainingMembers, std::size_t* structIndices) { std::size_t structIndex = ResolveStruct(structType); *structIndices++ = structIndex; assert(structIndex < m_structs.size()); const StructDescription& s = m_structs[structIndex]; auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberIdentifier[0]; }); if (memberIt == s.members.end()) throw AstError{ "unknown field " + memberIdentifier[0] }; const auto& member = *memberIt; if (remainingMembers > 1) return CheckField(member.type, memberIdentifier + 1, remainingMembers - 1, structIndices); else return member.type; } ExpressionPtr SanitizeVisitor::Clone(AccessMemberIdentifierExpression& node) { auto structExpr = CloneExpression(MandatoryExpr(node.structExpr)); const ExpressionType& exprType = GetExpressionType(*structExpr); if (IsVectorType(exprType)) { const VectorType& swizzledVec = std::get(exprType); // Swizzle expression auto swizzle = std::make_unique(); swizzle->expression = std::move(structExpr); // FIXME: Handle properly multiple identifiers (treat recursively) if (node.memberIdentifiers.size() != 1) throw AstError{ "invalid swizzle" }; const std::string& swizzleStr = node.memberIdentifiers.front(); if (swizzleStr.empty() || swizzleStr.size() > swizzle->components.size()) throw AstError{ "invalid swizzle" }; swizzle->componentCount = swizzleStr.size(); if (swizzle->componentCount > 1) swizzle->cachedExpressionType = VectorType{ swizzle->componentCount, swizzledVec.type }; else swizzle->cachedExpressionType = swizzledVec.type; for (std::size_t i = 0; i < swizzle->componentCount; ++i) { switch (swizzleStr[i]) { case 'r': case 'x': case 's': swizzle->components[i] = SwizzleComponent::First; break; case 'g': case 'y': case 't': swizzle->components[i] = SwizzleComponent::Second; break; case 'b': case 'z': case 'p': swizzle->components[i] = SwizzleComponent::Third; break; case 'a': case 'w': case 'q': swizzle->components[i] = SwizzleComponent::Fourth; break; } } return swizzle; } // Transform to AccessMemberIndexExpression auto accessMemberIndex = std::make_unique(); accessMemberIndex->structExpr = std::move(structExpr); StackArray structIndices = NazaraStackArrayNoInit(std::size_t, node.memberIdentifiers.size()); accessMemberIndex->cachedExpressionType = ResolveType(CheckField(exprType, node.memberIdentifiers.data(), node.memberIdentifiers.size(), structIndices.data())); accessMemberIndex->memberIndices.resize(node.memberIdentifiers.size()); for (std::size_t i = 0; i < node.memberIdentifiers.size(); ++i) { std::size_t structIndex = structIndices[i]; assert(structIndex < m_structs.size()); const StructDescription& structDesc = m_structs[structIndex]; auto it = std::find_if(structDesc.members.begin(), structDesc.members.end(), [&](const auto& member) { return member.name == node.memberIdentifiers[i]; }); assert(it != structDesc.members.end()); accessMemberIndex->memberIndices[i] = std::distance(structDesc.members.begin(), it); } return accessMemberIndex; } ExpressionPtr SanitizeVisitor::Clone(AssignExpression& node) { MandatoryExpr(node.left); MandatoryExpr(node.right); if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue) throw AstError{ "Assignation is only possible with a l-value" }; auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); TypeMustMatch(clone->left, clone->right); clone->cachedExpressionType = GetExpressionType(*clone->right); return clone; } ExpressionPtr SanitizeVisitor::Clone(BinaryExpression& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); const ExpressionType& leftExprType = GetExpressionType(MandatoryExpr(clone->left)); if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType)) throw AstError{ "left expression type does not support binary operation" }; const ExpressionType& rightExprType = GetExpressionType(MandatoryExpr(clone->right)); if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType)) throw AstError{ "right expression type does not support binary operation" }; if (IsPrimitiveType(leftExprType)) { PrimitiveType leftType = std::get(leftExprType); switch (clone->op) { case BinaryType::CompGe: case BinaryType::CompGt: case BinaryType::CompLe: case BinaryType::CompLt: if (leftType == PrimitiveType::Boolean) throw AstError{ "this operation is not supported for booleans" }; TypeMustMatch(clone->left, clone->right); clone->cachedExpressionType = PrimitiveType::Boolean; break; case BinaryType::Add: case BinaryType::CompEq: case BinaryType::CompNe: case BinaryType::Subtract: TypeMustMatch(clone->left, clone->right); clone->cachedExpressionType = leftExprType; break; case BinaryType::Multiply: case BinaryType::Divide: { switch (leftType) { case PrimitiveType::Float32: case PrimitiveType::Int32: case PrimitiveType::UInt32: { if (IsMatrixType(rightExprType)) { TypeMustMatch(leftType, std::get(rightExprType).type); clone->cachedExpressionType = rightExprType; } else if (IsPrimitiveType(rightExprType)) { TypeMustMatch(leftType, rightExprType); clone->cachedExpressionType = leftExprType; } else if (IsVectorType(rightExprType)) { TypeMustMatch(leftType, std::get(rightExprType).type); clone->cachedExpressionType = rightExprType; } else throw AstError{ "incompatible types" }; break; } case PrimitiveType::Boolean: throw AstError{ "this operation is not supported for booleans" }; default: throw AstError{ "incompatible types" }; } } } } else if (IsMatrixType(leftExprType)) { const MatrixType& leftType = std::get(leftExprType); switch (clone->op) { case BinaryType::CompGe: case BinaryType::CompGt: case BinaryType::CompLe: case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: TypeMustMatch(clone->left, clone->right); clone->cachedExpressionType = PrimitiveType::Boolean; break; case BinaryType::Add: case BinaryType::Subtract: TypeMustMatch(clone->left, clone->right); clone->cachedExpressionType = leftExprType; break; case BinaryType::Multiply: case BinaryType::Divide: { if (IsMatrixType(rightExprType)) { TypeMustMatch(leftExprType, rightExprType); clone->cachedExpressionType = leftExprType; //< FIXME } else if (IsPrimitiveType(rightExprType)) { TypeMustMatch(leftType.type, rightExprType); clone->cachedExpressionType = leftExprType; } else if (IsVectorType(rightExprType)) { const VectorType& rightType = std::get(rightExprType); TypeMustMatch(leftType.type, rightType.type); if (leftType.columnCount != rightType.componentCount) throw AstError{ "incompatible types" }; clone->cachedExpressionType = rightExprType; } else throw AstError{ "incompatible types" }; } } } else if (IsVectorType(leftExprType)) { const VectorType& leftType = std::get(leftExprType); switch (clone->op) { case BinaryType::CompGe: case BinaryType::CompGt: case BinaryType::CompLe: case BinaryType::CompLt: case BinaryType::CompEq: case BinaryType::CompNe: TypeMustMatch(clone->left, clone->right); clone->cachedExpressionType = PrimitiveType::Boolean; break; case BinaryType::Add: case BinaryType::Subtract: TypeMustMatch(clone->left, clone->right); clone->cachedExpressionType = leftExprType; break; case BinaryType::Multiply: case BinaryType::Divide: { if (IsPrimitiveType(rightExprType)) { TypeMustMatch(leftType.type, rightExprType); clone->cachedExpressionType = leftExprType; } else if (IsVectorType(rightExprType)) { TypeMustMatch(leftType, rightExprType); clone->cachedExpressionType = rightExprType; } else throw AstError{ "incompatible types" }; } } } return clone; } ExpressionPtr SanitizeVisitor::Clone(CastExpression& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t { if (IsVectorType(exprType)) return std::get(exprType).componentCount; else { assert(IsPrimitiveType(exprType)); return 1; } }; std::size_t componentCount = 0; std::size_t requiredComponents = GetComponentCount(clone->targetType); for (auto& exprPtr : clone->expressions) { if (!exprPtr) break; const ExpressionType& exprType = GetExpressionType(*exprPtr); if (!IsPrimitiveType(exprType) && !IsVectorType(exprType)) throw AstError{ "incompatible type" }; componentCount += GetComponentCount(exprType); } if (componentCount != requiredComponents) throw AstError{ "component count doesn't match required component count" }; clone->targetType = ResolveType(clone->targetType); clone->cachedExpressionType = clone->targetType; return clone; } ExpressionPtr SanitizeVisitor::Clone(ConditionalExpression& node) { MandatoryExpr(node.truePath); MandatoryExpr(node.falsePath); auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); const ExpressionType& leftExprType = GetExpressionType(*clone->truePath); if (leftExprType != GetExpressionType(*clone->falsePath)) throw AstError{ "true path type must match false path type" }; clone->cachedExpressionType = leftExprType; return clone; } ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); clone->cachedExpressionType = GetExpressionType(clone->value); return clone; } ExpressionPtr SanitizeVisitor::Clone(IdentifierExpression& node) { assert(m_context); const Identifier* identifier = FindIdentifier(node.identifier); if (!identifier) throw AstError{ "unknown identifier " + node.identifier }; if (!std::holds_alternative(identifier->value)) throw AstError{ "expected variable identifier" }; const Variable& variable = std::get(identifier->value); // Replace IdentifierExpression by VariableExpression auto varExpr = std::make_unique(); varExpr->cachedExpressionType = m_variables[variable.varIndex]; varExpr->variableId = variable.varIndex; return varExpr; } ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); // Parameter validation switch (clone->intrinsic) { case IntrinsicType::CrossProduct: case IntrinsicType::DotProduct: { if (clone->parameters.size() != 2) throw AstError { "Expected two parameters" }; for (auto& param : clone->parameters) MandatoryExpr(param); const ExpressionType& type = GetExpressionType(*clone->parameters.front()); for (std::size_t i = 1; i < clone->parameters.size(); ++i) { if (type != GetExpressionType(*clone->parameters[i])) throw AstError{ "All type must match" }; } break; } case IntrinsicType::Length: { if (clone->parameters.size() != 1) throw AstError{ "Expected only one parameters" }; for (auto& param : clone->parameters) MandatoryExpr(param); const ExpressionType& type = GetExpressionType(*clone->parameters.front()); if (!IsVectorType(type)) throw AstError{ "Expected a vector" }; break; } case IntrinsicType::SampleTexture: { if (clone->parameters.size() != 2) throw AstError{ "Expected two parameters" }; for (auto& param : clone->parameters) MandatoryExpr(param); if (!IsSamplerType(GetExpressionType(*clone->parameters[0]))) throw AstError{ "First parameter must be a sampler" }; if (!IsVectorType(GetExpressionType(*clone->parameters[1]))) throw AstError{ "Second parameter must be a vector" }; break; } } // Return type attribution switch (clone->intrinsic) { case IntrinsicType::CrossProduct: { const ExpressionType& type = GetExpressionType(*clone->parameters.front()); if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) throw AstError{ "CrossProduct only works with vec3 expressions" }; clone->cachedExpressionType = type; break; } case IntrinsicType::DotProduct: case IntrinsicType::Length: { ExpressionType type = GetExpressionType(*clone->parameters.front()); if (!IsVectorType(type)) throw AstError{ "DotProduct expects vector types" }; clone->cachedExpressionType = std::get(type).type; break; } case IntrinsicType::SampleTexture: { clone->cachedExpressionType = VectorType{ 4, std::get(GetExpressionType(*clone->parameters.front())).sampledType }; break; } } return clone; } ExpressionPtr SanitizeVisitor::Clone(SelectOptionExpression& node) { MandatoryExpr(node.truePath); MandatoryExpr(node.falsePath); auto condExpr = std::make_unique(); condExpr->truePath = CloneExpression(node.truePath); condExpr->falsePath = CloneExpression(node.falsePath); const Identifier* identifier = FindIdentifier(node.optionName); if (!identifier) throw AstError{ "unknown option " + node.optionName }; if (!std::holds_alternative