// Copyright (C) 2020 Jérôme Leclercq // This file is part of the "Nazara Engine - Shader generator" // For conditions of distribution and use, see copyright notice in Config.hpp #include #include #include #include #include #include namespace Nz::ShaderAst { struct AstError { std::string errMsg; }; struct AstValidator::Context { //const ShaderAst::Function* currentFunction; std::optional activeScopeId; AstCache* cache; }; bool AstValidator::Validate(StatementPtr& node, std::string* error, AstCache* cache) { try { AstCache dummy; Context currentContext; currentContext.cache = (cache) ? cache : &dummy; m_context = ¤tContext; CallOnExit resetContext([&] { m_context = nullptr; }); EnterScope(); node->Visit(*this); ExitScope(); return true; } catch (const AstError& e) { if (error) *error = e.errMsg; return false; } } Expression& AstValidator::MandatoryExpr(ExpressionPtr& node) { if (!node) throw AstError{ "Invalid expression" }; return *node; } Statement& AstValidator::MandatoryStatement(StatementPtr& node) { if (!node) throw AstError{ "Invalid statement" }; return *node; } void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right) { return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache)); } void AstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right) { if (left != right) throw AstError{ "Left expression type must match right expression type" }; } ShaderExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers) { const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName); if (!identifier) throw AstError{ "unknown identifier " + structName }; if (std::holds_alternative(identifier->value)) throw AstError{ "identifier is not a struct" }; const StructDescription& s = std::get(identifier->value); auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; }); if (memberIt == s.members.end()) throw AstError{ "unknown field " + memberIdentifier[0]}; const auto& member = *memberIt; if (remainingMembers > 1) return CheckField(std::get(member.type), memberIdentifier + 1, remainingMembers - 1); else return member.type; } AstCache::Scope& AstValidator::EnterScope() { std::size_t newScopeId = m_context->cache->scopes.size(); std::optional previousScope = m_context->activeScopeId; auto& newScope = m_context->cache->scopes.emplace_back(); newScope.parentScopeIndex = previousScope; m_context->activeScopeId = newScopeId; return m_context->cache->scopes[newScopeId]; } void AstValidator::ExitScope() { assert(m_context->activeScopeId); auto& previousScope = m_context->cache->scopes[*m_context->activeScopeId]; m_context->activeScopeId = previousScope.parentScopeIndex; } void AstValidator::RegisterExpressionType(Expression& node, ShaderExpressionType expressionType) { m_context->cache->nodeExpressionType[&node] = std::move(expressionType); } void AstValidator::RegisterScope(Node& node) { if (m_context->activeScopeId) m_context->cache->scopeIdByNode[&node] = *m_context->activeScopeId; } void AstValidator::Visit(AccessMemberExpression& node) { RegisterScope(node); const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache); if (!IsStructType(exprType)) throw AstError{ "expression is not a structure" }; const std::string& structName = std::get(exprType); RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size())); } void AstValidator::Visit(AssignExpression& node) { RegisterScope(node); MandatoryExpr(node.left); MandatoryExpr(node.right); TypeMustMatch(node.left, node.right); if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue) throw AstError { "Assignation is only possible with a l-value" }; AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(BinaryExpression& node) { RegisterScope(node); // Register expression type AstRecursiveVisitor::Visit(node); const ShaderExpressionType& leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache); if (!IsBasicType(leftExprType)) throw AstError{ "left expression type does not support binary operation" }; const ShaderExpressionType& rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache); if (!IsBasicType(rightExprType)) throw AstError{ "right expression type does not support binary operation" }; BasicType leftType = std::get(leftExprType); BasicType rightType = std::get(rightExprType); switch (node.op) { case BinaryType::CompGe: case BinaryType::CompGt: case BinaryType::CompLe: case BinaryType::CompLt: if (leftType == BasicType::Boolean) throw AstError{ "this operation is not supported for booleans" }; [[fallthrough]]; case BinaryType::Add: case BinaryType::CompEq: case BinaryType::CompNe: case BinaryType::Subtract: TypeMustMatch(node.left, node.right); break; case BinaryType::Multiply: case BinaryType::Divide: { switch (leftType) { case BasicType::Float1: case BasicType::Int1: { if (GetComponentType(rightType) != leftType) throw AstError{ "Left expression type is not compatible with right expression type" }; break; } case BasicType::Float2: case BasicType::Float3: case BasicType::Float4: case BasicType::Int2: case BasicType::Int3: case BasicType::Int4: { if (leftType != rightType && rightType != GetComponentType(leftType)) throw AstError{ "Left expression type is not compatible with right expression type" }; break; } case BasicType::Mat4x4: { switch (rightType) { case BasicType::Float1: case BasicType::Float4: case BasicType::Mat4x4: break; default: TypeMustMatch(node.left, node.right); } break; } default: TypeMustMatch(node.left, node.right); break; } } } } void AstValidator::Visit(CastExpression& node) { RegisterScope(node); unsigned int componentCount = 0; unsigned int requiredComponents = GetComponentCount(node.targetType); for (auto& exprPtr : node.expressions) { if (!exprPtr) break; ShaderExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache); if (!IsBasicType(exprType)) throw AstError{ "incompatible type" }; componentCount += GetComponentCount(std::get(exprType)); } if (componentCount != requiredComponents) throw AstError{ "component count doesn't match required component count" }; AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(ConditionalExpression& node) { MandatoryExpr(node.truePath); MandatoryExpr(node.falsePath); RegisterScope(node); AstRecursiveVisitor::Visit(node); //if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) // throw AstError{ "condition not found" }; } void AstValidator::Visit(ConstantExpression& node) { RegisterScope(node); } void AstValidator::Visit(IdentifierExpression& node) { assert(m_context); if (!m_context->activeScopeId) throw AstError{ "no scope" }; RegisterScope(node); const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, node.identifier); if (!identifier) throw AstError{ "Unknown variable " + node.identifier }; } void AstValidator::Visit(IntrinsicExpression& node) { RegisterScope(node); switch (node.intrinsic) { case IntrinsicType::CrossProduct: case IntrinsicType::DotProduct: { if (node.parameters.size() != 2) throw AstError { "Expected 2 parameters" }; for (auto& param : node.parameters) MandatoryExpr(param); ShaderExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache); for (std::size_t i = 1; i < node.parameters.size(); ++i) { if (type != GetExpressionType(MandatoryExpr(node.parameters[i])), m_context->cache) throw AstError{ "All type must match" }; } break; } } switch (node.intrinsic) { case IntrinsicType::CrossProduct: { if (GetExpressionType(*node.parameters[0]) != ShaderExpressionType{ BasicType::Float3 }, m_context->cache) throw AstError{ "CrossProduct only works with Float3 expressions" }; break; } case IntrinsicType::DotProduct: break; } AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(SwizzleExpression& node) { RegisterScope(node); if (node.componentCount > 4) throw AstError{ "Cannot swizzle more than four elements" }; const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache); if (!IsBasicType(exprType)) throw AstError{ "Cannot swizzle this type" }; switch (std::get(exprType)) { case BasicType::Float1: case BasicType::Float2: case BasicType::Float3: case BasicType::Float4: case BasicType::Int1: case BasicType::Int2: case BasicType::Int3: case BasicType::Int4: break; default: throw AstError{ "Cannot swizzle this type" }; } AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(BranchStatement& node) { RegisterScope(node); for (auto& condStatement : node.condStatements) { const ShaderExpressionType& condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache); if (!IsBasicType(condType) || std::get(condType) != BasicType::Boolean) throw AstError{ "if expression must resolve to boolean type" }; MandatoryStatement(condStatement.statement); } AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(ConditionalStatement& node) { MandatoryStatement(node.statement); RegisterScope(node); AstRecursiveVisitor::Visit(node); //if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition) // throw AstError{ "condition not found" }; } void AstValidator::Visit(DeclareFunctionStatement& node) { auto& scope = EnterScope(); RegisterScope(node); for (auto& parameter : node.parameters) { auto& identifier = scope.identifiers.emplace_back(); identifier = AstCache::Identifier{ parameter.name, AstCache::Variable { parameter.type } }; } for (auto& statement : node.statements) MandatoryStatement(statement).Visit(*this); ExitScope(); } void AstValidator::Visit(DeclareStructStatement& node) { assert(m_context); if (!m_context->activeScopeId) throw AstError{ "cannot declare variable without scope" }; RegisterScope(node); auto& scope = m_context->cache->scopes[*m_context->activeScopeId]; auto& identifier = scope.identifiers.emplace_back(); identifier = AstCache::Identifier{ node.description.name, node.description }; AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(DeclareVariableStatement& node) { assert(m_context); if (!m_context->activeScopeId) throw AstError{ "cannot declare variable without scope" }; RegisterScope(node); auto& scope = m_context->cache->scopes[*m_context->activeScopeId]; auto& identifier = scope.identifiers.emplace_back(); identifier = AstCache::Identifier{ node.varName, AstCache::Variable { node.varType } }; AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(ExpressionStatement& node) { RegisterScope(node); MandatoryExpr(node.expression); AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(MultiStatement& node) { assert(m_context); EnterScope(); RegisterScope(node); for (auto& statement : node.statements) MandatoryStatement(statement); ExitScope(); AstRecursiveVisitor::Visit(node); } void AstValidator::Visit(ReturnStatement& node) { RegisterScope(node); /*if (m_context->currentFunction->returnType != ShaderExpressionType(BasicType::Void)) { if (GetExpressionType(MandatoryExpr(node.returnExpr)) != m_context->currentFunction->returnType) throw AstError{ "Return type doesn't match function return type" }; } else { if (node.returnExpr) throw AstError{ "Unexpected expression for return (function doesn't return)" }; }*/ AstRecursiveVisitor::Visit(node); } bool ValidateAst(StatementPtr& node, std::string* error, AstCache* cache) { AstValidator validator; return validator.Validate(node, error, cache); } }