|
|
|
|
@@ -577,6 +577,18 @@ namespace Nz::ShaderAst
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(ConstantIndexExpression& node)
|
|
|
|
|
{
|
|
|
|
|
if (node.constantId >= m_context->constantValues.size())
|
|
|
|
|
throw AstError{ "invalid constant index " + std::to_string(node.constantId) };
|
|
|
|
|
|
|
|
|
|
// Replace by constant value
|
|
|
|
|
auto constant = ShaderBuilder::Constant(m_context->constantValues[node.constantId]);
|
|
|
|
|
constant->cachedExpressionType = GetExpressionType(constant->value);
|
|
|
|
|
|
|
|
|
|
return constant;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(IdentifierExpression& node)
|
|
|
|
|
{
|
|
|
|
|
assert(m_context);
|
|
|
|
|
@@ -712,11 +724,46 @@ namespace Nz::ShaderAst
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(VariableExpression& node)
|
|
|
|
|
{
|
|
|
|
|
if (node.variableId >= m_context->variableTypes.size())
|
|
|
|
|
throw AstError{ "invalid constant index " + std::to_string(node.variableId) };
|
|
|
|
|
|
|
|
|
|
node.cachedExpressionType = m_context->variableTypes[node.variableId];
|
|
|
|
|
|
|
|
|
|
return AstCloner::Clone(node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(BranchStatement& node)
|
|
|
|
|
{
|
|
|
|
|
if (node.isConst)
|
|
|
|
|
{
|
|
|
|
|
// Evaluate every condition at compilation and select the right statement
|
|
|
|
|
for (auto& cond : node.condStatements)
|
|
|
|
|
{
|
|
|
|
|
MandatoryExpr(cond.condition);
|
|
|
|
|
|
|
|
|
|
ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*cond.condition));
|
|
|
|
|
if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean })
|
|
|
|
|
throw AstError{ "expected a boolean value" };
|
|
|
|
|
|
|
|
|
|
if (std::get<bool>(conditionValue))
|
|
|
|
|
return AstCloner::Clone(*cond.statement);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Every condition failed, fallback to else if any
|
|
|
|
|
if (node.elseStatement)
|
|
|
|
|
return AstCloner::Clone(*node.elseStatement);
|
|
|
|
|
else
|
|
|
|
|
return ShaderBuilder::NoOp();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto clone = std::make_unique<BranchStatement>();
|
|
|
|
|
clone->condStatements.reserve(node.condStatements.size());
|
|
|
|
|
|
|
|
|
|
if (!m_context->currentFunction)
|
|
|
|
|
throw AstError{ "non-const branching statements can only exist inside a function" };
|
|
|
|
|
|
|
|
|
|
for (auto& cond : node.condStatements)
|
|
|
|
|
{
|
|
|
|
|
PushScope();
|
|
|
|
|
@@ -758,6 +805,31 @@ namespace Nz::ShaderAst
|
|
|
|
|
return ShaderBuilder::NoOp();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(DeclareConstStatement& node)
|
|
|
|
|
{
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareConstStatement>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
if (!clone->expression)
|
|
|
|
|
throw AstError{ "const variables must have an expression" };
|
|
|
|
|
|
|
|
|
|
clone->expression = Optimize(*clone->expression);
|
|
|
|
|
if (clone->expression->GetType() != NodeType::ConstantExpression)
|
|
|
|
|
throw AstError{ "const variable must have constant expressions " };
|
|
|
|
|
|
|
|
|
|
const ConstantValue& value = static_cast<ConstantExpression&>(*clone->expression).value;
|
|
|
|
|
|
|
|
|
|
ExpressionType expressionType = ResolveType(GetExpressionType(value));
|
|
|
|
|
|
|
|
|
|
if (!IsNoType(clone->type) && ResolveType(clone->type) != expressionType)
|
|
|
|
|
throw AstError{ "constant expression doesn't match type" };
|
|
|
|
|
|
|
|
|
|
clone->type = expressionType;
|
|
|
|
|
|
|
|
|
|
clone->constIndex = RegisterConstant(clone->name, value);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node)
|
|
|
|
|
{
|
|
|
|
|
assert(m_context);
|
|
|
|
|
@@ -815,6 +887,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node)
|
|
|
|
|
{
|
|
|
|
|
if (m_context->currentFunction)
|
|
|
|
|
throw AstError{ "a function cannot be defined inside another function" };
|
|
|
|
|
|
|
|
|
|
auto clone = std::make_unique<DeclareFunctionStatement>();
|
|
|
|
|
clone->name = node.name;
|
|
|
|
|
clone->parameters = node.parameters;
|
|
|
|
|
@@ -908,6 +983,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(DeclareOptionStatement& node)
|
|
|
|
|
{
|
|
|
|
|
if (m_context->currentFunction)
|
|
|
|
|
throw AstError{ "options must be declared outside of functions" };
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareOptionStatement>(AstCloner::Clone(node));
|
|
|
|
|
clone->optType = ResolveType(clone->optType);
|
|
|
|
|
|
|
|
|
|
@@ -926,6 +1004,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node)
|
|
|
|
|
{
|
|
|
|
|
if (m_context->currentFunction)
|
|
|
|
|
throw AstError{ "structs must be declared outside of functions" };
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareStructStatement>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::string> declaredMembers;
|
|
|
|
|
@@ -961,6 +1042,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(DeclareVariableStatement& node)
|
|
|
|
|
{
|
|
|
|
|
if (!m_context->currentFunction)
|
|
|
|
|
throw AstError{ "global variables outside of external blocks are forbidden" };
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareVariableStatement>(AstCloner::Clone(node));
|
|
|
|
|
if (IsNoType(clone->varType))
|
|
|
|
|
{
|
|
|
|
|
@@ -1092,6 +1176,17 @@ namespace Nz::ShaderAst
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr)
|
|
|
|
|
{
|
|
|
|
|
// Run optimizer on constant value to hopefully retrieve a single constant value
|
|
|
|
|
ExpressionPtr optimizedExpr = Optimize(expr);
|
|
|
|
|
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
|
|
|
|
|
throw AstError{"expected a constant expression"};
|
|
|
|
|
|
|
|
|
|
return static_cast<ConstantExpression&>(*optimizedExpr).value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename T>
|
|
|
|
|
std::unique_ptr<T> SanitizeVisitor::Optimize(T& node)
|
|
|
|
|
{
|
|
|
|
|
AstOptimizer::Options optimizerOptions;
|
|
|
|
|
optimizerOptions.constantQueryCallback = [this](std::size_t constantId)
|
|
|
|
|
@@ -1103,11 +1198,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
optimizerOptions.enabledOptions = m_context->options.enabledOptions;
|
|
|
|
|
|
|
|
|
|
// Run optimizer on constant value to hopefully retrieve a single constant value
|
|
|
|
|
ExpressionPtr optimizedExpr = Optimize(expr, optimizerOptions);
|
|
|
|
|
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
|
|
|
|
|
throw AstError{"expected a constant expression"};
|
|
|
|
|
|
|
|
|
|
return static_cast<ConstantExpression&>(*optimizedExpr).value;
|
|
|
|
|
return static_unique_pointer_cast<T>(ShaderAst::Optimize(node, optimizerOptions));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement& funcDecl)
|
|
|
|
|
|