503 lines
13 KiB
C++
503 lines
13 KiB
C++
// Copyright (C) 2020 Jérôme Leclercq
|
|
// This file is part of the "Nazara Engine - Shader generator"
|
|
// For conditions of distribution and use, see copyright notice in Config.hpp
|
|
|
|
#include <Nazara/Shader/ShaderAstValidator.hpp>
|
|
#include <Nazara/Core/CallOnExit.hpp>
|
|
#include <Nazara/Shader/ShaderAstUtils.hpp>
|
|
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
|
#include <vector>
|
|
#include <Nazara/Shader/Debug.hpp>
|
|
|
|
namespace Nz::ShaderAst
|
|
{
|
|
// Exception thrown internally to abort validation; AstValidator::Validate catches it and reports errMsg.
struct AstError
{
	std::string errMsg; // description of the validation failure
};
|
|
|
|
struct AstValidator::Context
|
|
{
|
|
//const ShaderAst::Function* currentFunction;
|
|
std::optional<std::size_t> activeScopeId;
|
|
AstCache* cache;
|
|
};
|
|
|
|
/// Validates the AST rooted at `node`, filling `cache` with scopes, identifiers and expression types.
/// @param node  Root statement; a null root is reported as an "Invalid statement" error.
/// @param error When non-null, receives the error message if validation fails.
/// @param cache Cache to fill; when null, a temporary cache local to this call is used instead.
/// @return true when the AST is valid, false when an AstError was raised during validation.
bool AstValidator::Validate(StatementPtr& node, std::string* error, AstCache* cache)
{
	try
	{
		AstCache dummy;

		Context currentContext;
		currentContext.cache = (cache) ? cache : &dummy;

		m_context = &currentContext;
		// m_context points to a local: make sure it is cleared on every exit path, including when an AstError propagates
		CallOnExit resetContext([&] { m_context = nullptr; });

		// Root scope, ancestor of every scope created while visiting
		EnterScope();
		MandatoryStatement(node).Visit(*this);
		ExitScope();

		return true;
	}
	catch (const AstError& e)
	{
		if (error)
			*error = e.errMsg;

		return false;
	}
}
|
|
|
|
/// Returns the expression held by `node`; throws an AstError ("Invalid expression") when it is null.
Expression& AstValidator::MandatoryExpr(ExpressionPtr& node)
{
	if (node)
		return *node;

	throw AstError{ "Invalid expression" };
}
|
|
|
|
/// Returns the statement held by `node`; throws an AstError ("Invalid statement") when it is null.
Statement& AstValidator::MandatoryStatement(StatementPtr& node)
{
	if (node)
		return *node;

	throw AstError{ "Invalid statement" };
}
|
|
|
|
/// Throws an AstError unless both expressions are present and resolve to the same type.
void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
{
	// MandatoryExpr reports a missing operand instead of dereferencing a null pointer
	const ShaderExpressionType leftType = GetExpressionType(MandatoryExpr(left), m_context->cache);
	const ShaderExpressionType rightType = GetExpressionType(MandatoryExpr(right), m_context->cache);

	TypeMustMatch(leftType, rightType);
}
|
|
|
|
void AstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right)
|
|
{
|
|
if (left != right)
|
|
throw AstError{ "Left expression type must match right expression type" };
|
|
}
|
|
|
|
/// Resolves a (possibly nested) member access on the structure named `structName`.
/// @param structName       Name of the structure, looked up from the active scope.
/// @param memberIdentifier Pointer to the first of `remainingMembers` member names (e.g. {"a", "b"} for s.a.b).
/// @param remainingMembers Number of member names still to resolve; must be at least 1.
/// @return Type of the last member of the chain.
/// @throws AstError if no scope is active, the chain is empty, a name is unknown, or an intermediate
///         identifier/member is not a structure.
ShaderExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
	if (!m_context->activeScopeId)
		throw AstError{ "no scope" };

	if (remainingMembers == 0)
		throw AstError{ "expected at least one member identifier" };

	const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName);
	if (!identifier)
		throw AstError{ "unknown identifier " + structName };

	// The identifier must name a structure declaration (not a variable)
	if (!std::holds_alternative<StructDescription>(identifier->value))
		throw AstError{ "identifier is not a struct" };

	const StructDescription& s = std::get<StructDescription>(identifier->value);

	const std::string& memberName = memberIdentifier[0];
	auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberName; });
	if (memberIt == s.members.end())
		throw AstError{ "unknown field " + memberName };

	const auto& member = *memberIt;
	if (remainingMembers == 1)
		return member.type;

	// Going deeper requires this member to be a structure itself (struct types are stored by name)
	if (!std::holds_alternative<std::string>(member.type))
		throw AstError{ "field " + memberName + " is not a structure" };

	return CheckField(std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
}
|
|
|
|
/// Appends a new scope to the cache, parented to the currently active scope (if any), and makes it active.
/// @return The newly created scope.
AstCache::Scope& AstValidator::EnterScope()
{
	auto& scopes = m_context->cache->scopes;
	const std::size_t scopeIndex = scopes.size();

	AstCache::Scope& scope = scopes.emplace_back();
	scope.parentScopeIndex = m_context->activeScopeId;

	m_context->activeScopeId = scopeIndex;
	return scope;
}
|
|
|
|
void AstValidator::ExitScope()
|
|
{
|
|
assert(m_context->activeScopeId);
|
|
auto& previousScope = m_context->cache->scopes[*m_context->activeScopeId];
|
|
m_context->activeScopeId = previousScope.parentScopeIndex;
|
|
}
|
|
|
|
/// Stores `expressionType` as the type of `node` in the cache, overwriting any previous entry.
void AstValidator::RegisterExpressionType(Expression& node, ShaderExpressionType expressionType)
{
	auto& typeByNode = m_context->cache->nodeExpressionType;
	typeByNode[&node] = std::move(expressionType);
}
|
|
|
|
/// Records in the cache which scope `node` belongs to; does nothing when no scope is active.
void AstValidator::RegisterScope(Node& node)
{
	const std::optional<std::size_t>& scopeId = m_context->activeScopeId;
	if (!scopeId)
		return;

	m_context->cache->scopeIdByNode[&node] = *scopeId;
}
|
|
|
|
/// Validates a member access chain (e.g. s.a.b) and registers the resulting member type in the cache.
void AstValidator::Visit(AccessMemberExpression& node)
{
	RegisterScope(node);

	Expression& structExpr = MandatoryExpr(node.structExpr);

	// Validate the accessed expression first: it gets checked, scoped, and has its own type registered
	// before we query it below
	AstRecursiveVisitor::Visit(node);

	const ShaderExpressionType& exprType = GetExpressionType(structExpr, m_context->cache);
	if (!IsStructType(exprType))
		throw AstError{ "expression is not a structure" };

	const std::string& structName = std::get<std::string>(exprType);

	RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()));
}
|
|
|
|
/// Validates an assignment: both sides must be present and of the same type, and the left side must be an l-value.
void AstValidator::Visit(AssignExpression& node)
{
	RegisterScope(node);

	MandatoryExpr(node.left);
	MandatoryExpr(node.right);

	// Visit operands before comparing their types, so types registered while visiting them
	// (e.g. by member accesses) are available to GetExpressionType
	AstRecursiveVisitor::Visit(node);

	TypeMustMatch(node.left, node.right);

	if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
		throw AstError { "Assignation is only possible with a l-value" };
}
|
|
|
|
/// Validates a binary operation: both operands must be present and of basic type, and their types must be
/// compatible with the operator (same type for add/sub/comparisons, scalar/vector/matrix rules for mul/div).
void AstValidator::Visit(BinaryExpression& node)
{
	RegisterScope(node);

	// Check operands exist before the recursive visit dereferences them
	Expression& leftExpr = MandatoryExpr(node.left);
	Expression& rightExpr = MandatoryExpr(node.right);

	// Visit operands first: this validates them and registers their expression types
	AstRecursiveVisitor::Visit(node);

	const ShaderExpressionType leftExprType = GetExpressionType(leftExpr, m_context->cache);
	if (!IsBasicType(leftExprType))
		throw AstError{ "left expression type does not support binary operation" };

	const ShaderExpressionType rightExprType = GetExpressionType(rightExpr, m_context->cache);
	if (!IsBasicType(rightExprType))
		throw AstError{ "right expression type does not support binary operation" };

	const BasicType leftType = std::get<BasicType>(leftExprType);
	const BasicType rightType = std::get<BasicType>(rightExprType);

	switch (node.op)
	{
		case BinaryType::CompGe:
		case BinaryType::CompGt:
		case BinaryType::CompLe:
		case BinaryType::CompLt:
			if (leftType == BasicType::Boolean)
				throw AstError{ "this operation is not supported for booleans" };

			[[fallthrough]];
		case BinaryType::Add:
		case BinaryType::CompEq:
		case BinaryType::CompNe:
		case BinaryType::Subtract:
			// Types were already computed above, no need to resolve them again
			TypeMustMatch(leftExprType, rightExprType);
			break;

		case BinaryType::Multiply:
		case BinaryType::Divide:
		{
			switch (leftType)
			{
				case BasicType::Float1:
				case BasicType::Int1:
				{
					// Scalar on the left: the right operand's component type must be that scalar type
					if (GetComponentType(rightType) != leftType)
						throw AstError{ "Left expression type is not compatible with right expression type" };

					break;
				}

				case BasicType::Float2:
				case BasicType::Float3:
				case BasicType::Float4:
				case BasicType::Int2:
				case BasicType::Int3:
				case BasicType::Int4:
				{
					// Vector on the left: right operand is either the same vector type or its component type
					if (leftType != rightType && rightType != GetComponentType(leftType))
						throw AstError{ "Left expression type is not compatible with right expression type" };

					break;
				}

				case BasicType::Mat4x4:
				{
					switch (rightType)
					{
						case BasicType::Float1:
						case BasicType::Float4:
						case BasicType::Mat4x4:
							break;

						default:
							TypeMustMatch(leftExprType, rightExprType);
					}

					break;
				}

				default:
					TypeMustMatch(leftExprType, rightExprType);
					break;
			}

			break;
		}
	}
}
|
|
|
|
/// Validates a cast/constructor: every source expression must be of basic type and their component counts
/// must sum to the component count of the target type.
void AstValidator::Visit(CastExpression& node)
{
	RegisterScope(node);

	// Validate source expressions first so their types are registered before being queried below
	AstRecursiveVisitor::Visit(node);

	unsigned int componentCount = 0;
	unsigned int requiredComponents = GetComponentCount(node.targetType);
	for (auto& exprPtr : node.expressions)
	{
		// The first null entry ends the list of source expressions
		if (!exprPtr)
			break;

		ShaderExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
		if (!IsBasicType(exprType))
			throw AstError{ "incompatible type" };

		componentCount += GetComponentCount(std::get<BasicType>(exprType));
	}

	if (componentCount != requiredComponents)
		throw AstError{ "component count doesn't match required component count" };
}
|
|
|
|
/// Validates a conditional expression: both the true and the false path must be present.
void AstValidator::Visit(ConditionalExpression& node)
{
	for (ExpressionPtr* path : { &node.truePath, &node.falsePath })
		MandatoryExpr(*path);

	RegisterScope(node);

	AstRecursiveVisitor::Visit(node);

	// TODO: check that the condition exists, e.g.
	// if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
	//     throw AstError{ "condition not found" };
}
|
|
|
|
/// Constants need no validation; only remember which scope they appear in.
void AstValidator::Visit(ConstantExpression& node)
{
	RegisterScope(node); // leaf node: nothing to recurse into
}
|
|
|
|
/// Validates an identifier reference: a scope must be active and the name must be found by
/// FindIdentifier starting from that scope.
void AstValidator::Visit(IdentifierExpression& node)
{
	assert(m_context);

	const std::optional<std::size_t>& scopeId = m_context->activeScopeId;
	if (!scopeId)
		throw AstError{ "no scope" };

	RegisterScope(node);

	if (!m_context->cache->FindIdentifier(*scopeId, node.identifier))
		throw AstError{ "Unknown variable " + node.identifier };
}
|
|
|
|
/// Validates an intrinsic call: parameters must be present; cross and dot products take exactly two
/// parameters of the same type, and cross product additionally requires Float3 operands.
void AstValidator::Visit(IntrinsicExpression& node)
{
	RegisterScope(node);

	// Check parameters exist before the recursive visit dereferences them
	for (auto& param : node.parameters)
		MandatoryExpr(param);

	// Validate parameters first so their types are registered before being compared below
	AstRecursiveVisitor::Visit(node);

	switch (node.intrinsic)
	{
		case IntrinsicType::CrossProduct:
		case IntrinsicType::DotProduct:
		{
			if (node.parameters.size() != 2)
				throw AstError { "Expected 2 parameters" };

			// Note: the cache must be passed *inside* the GetExpressionType call; the previous code closed the
			// parenthesis early, turning the condition into a comma expression that was always true.
			const ShaderExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
			for (std::size_t i = 1; i < node.parameters.size(); ++i)
			{
				if (type != GetExpressionType(*node.parameters[i], m_context->cache))
					throw AstError{ "All type must match" };
			}

			if (node.intrinsic == IntrinsicType::CrossProduct && type != ShaderExpressionType{ BasicType::Float3 })
				throw AstError{ "CrossProduct only works with Float3 expressions" };

			break;
		}
	}
}
|
|
|
|
/// Validates a swizzle: at most four components, applied to a scalar or vector Float/Int expression.
void AstValidator::Visit(SwizzleExpression& node)
{
	RegisterScope(node);

	if (node.componentCount > 4)
		throw AstError{ "Cannot swizzle more than four elements" };

	Expression& swizzledExpr = MandatoryExpr(node.expression);

	// Validate the swizzled expression first so its type is registered before being queried below
	AstRecursiveVisitor::Visit(node);

	const ShaderExpressionType exprType = GetExpressionType(swizzledExpr, m_context->cache);
	if (!IsBasicType(exprType))
		throw AstError{ "Cannot swizzle this type" };

	switch (std::get<BasicType>(exprType))
	{
		case BasicType::Float1:
		case BasicType::Float2:
		case BasicType::Float3:
		case BasicType::Float4:
		case BasicType::Int1:
		case BasicType::Int2:
		case BasicType::Int3:
		case BasicType::Int4:
			break;

		default:
			throw AstError{ "Cannot swizzle this type" };
	}
}
|
|
|
|
/// Validates an if/else-if chain: every branch needs a condition of Boolean type and a statement.
void AstValidator::Visit(BranchStatement& node)
{
	RegisterScope(node);

	// Check every condition/statement exists before the recursive visit dereferences them
	for (auto& condStatement : node.condStatements)
	{
		MandatoryExpr(condStatement.condition);
		MandatoryStatement(condStatement.statement);
	}

	// Visit children first so condition types are registered before being queried below
	AstRecursiveVisitor::Visit(node);

	for (auto& condStatement : node.condStatements)
	{
		const ShaderExpressionType condType = GetExpressionType(*condStatement.condition, m_context->cache);
		if (!IsBasicType(condType) || std::get<BasicType>(condType) != BasicType::Boolean)
			throw AstError{ "if expression must resolve to boolean type" };
	}
}
|
|
|
|
/// Validates a statement guarded by a named condition: the guarded statement must be present.
void AstValidator::Visit(ConditionalStatement& node)
{
	MandatoryStatement(node.statement);
	RegisterScope(node);
	AstRecursiveVisitor::Visit(node);

	// TODO: check that the condition exists, e.g.
	// if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
	//     throw AstError{ "condition not found" };
}
|
|
|
|
/// Validates a function declaration: parameters are declared as variables in a fresh scope, then every
/// body statement (which must be non-null) is visited inside that scope.
void AstValidator::Visit(DeclareFunctionStatement& node)
{
	AstCache::Scope& functionScope = EnterScope();
	RegisterScope(node);

	for (auto& parameter : node.parameters)
		functionScope.identifiers.push_back(AstCache::Identifier{ parameter.name, AstCache::Variable { parameter.type } });

	// functionScope is deliberately not used below: visiting the body can create new scopes, which may
	// invalidate that reference if the scope storage reallocates
	for (auto& statement : node.statements)
		MandatoryStatement(statement).Visit(*this);

	ExitScope();
}
|
|
|
|
/// Declares the structure `node.description` in the active scope so later lookups (e.g. member accesses)
/// can resolve it by name.
/// @throws AstError if no scope is active.
void AstValidator::Visit(DeclareStructStatement& node)
{
	assert(m_context);

	// Message previously said "variable" (copy-paste from DeclareVariableStatement)
	if (!m_context->activeScopeId)
		throw AstError{ "cannot declare structure without scope" };

	RegisterScope(node);

	auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
	scope.identifiers.push_back(AstCache::Identifier{ node.description.name, node.description });

	AstRecursiveVisitor::Visit(node);
}
|
|
|
|
/// Declares the variable `node.varName` (of type `node.varType`) in the active scope, then visits the
/// statement's children.
/// @throws AstError if no scope is active.
void AstValidator::Visit(DeclareVariableStatement& node)
{
	assert(m_context);

	const std::optional<std::size_t>& scopeId = m_context->activeScopeId;
	if (!scopeId)
		throw AstError{ "cannot declare variable without scope" };

	RegisterScope(node);

	AstCache::Scope& scope = m_context->cache->scopes[*scopeId];
	scope.identifiers.push_back(AstCache::Identifier{ node.varName, AstCache::Variable { node.varType } });

	AstRecursiveVisitor::Visit(node);
}
|
|
|
|
/// Validates an expression used as a statement: the expression must be present, then it is visited.
void AstValidator::Visit(ExpressionStatement& node)
{
	RegisterScope(node);
	MandatoryExpr(node.expression); // throws on a missing expression
	AstRecursiveVisitor::Visit(node);
}
|
|
|
|
/// Validates a block of statements inside its own scope, so declarations made in the block stay local to it.
void AstValidator::Visit(MultiStatement& node)
{
	assert(m_context);

	EnterScope();

	RegisterScope(node);

	for (auto& statement : node.statements)
		MandatoryStatement(statement);

	// Children must be visited while the block's scope is still active, so that their declarations and
	// scope registrations go into this scope rather than the enclosing one
	AstRecursiveVisitor::Visit(node);

	ExitScope();
}
|
|
|
|
/// Validates a return statement by visiting its children.
/// TODO: once Context tracks the current function, check the returned expression against its return type
/// (see the disabled check in the body).
void AstValidator::Visit(ReturnStatement& node)
{
	RegisterScope(node);

	/* Disabled until Context::currentFunction exists:
	if (m_context->currentFunction->returnType != ShaderExpressionType(BasicType::Void))
	{
		if (GetExpressionType(MandatoryExpr(node.returnExpr)) != m_context->currentFunction->returnType)
			throw AstError{ "Return type doesn't match function return type" };
	}
	else
	{
		if (node.returnExpr)
			throw AstError{ "Unexpected expression for return (function doesn't return)" };
	}
	*/

	AstRecursiveVisitor::Visit(node);
}
|
|
|
|
/// Convenience wrapper: validates `node` with a fresh AstValidator (same contract as AstValidator::Validate).
bool ValidateAst(StatementPtr& node, std::string* error, AstCache* cache)
{
	return AstValidator{}.Validate(node, error, cache);
}
|
|
}
|