Shader: Implement const if and const values

This commit is contained in:
Jérôme Leclercq
2021-07-07 21:38:23 +02:00
parent d679eccb43
commit 1f6937ab1b
28 changed files with 315 additions and 60 deletions

View File

@@ -40,6 +40,7 @@ namespace Nz::ShaderAst
{
auto clone = std::make_unique<BranchStatement>();
clone->condStatements.reserve(node.condStatements.size());
clone->isConst = node.isConst;
for (auto& cond : node.condStatements)
{
@@ -62,6 +63,17 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr AstCloner::Clone(DeclareConstStatement& node)
{
auto clone = std::make_unique<DeclareConstStatement>();
clone->constIndex = node.constIndex;
clone->name = node.name;
clone->type = node.type;
clone->expression = CloneExpression(node.expression);
return clone;
}
StatementPtr AstCloner::Clone(DeclareExternalStatement& node)
{
auto clone = std::make_unique<DeclareExternalStatement>();

View File

@@ -844,7 +844,10 @@ namespace Nz::ShaderAst
if (!m_options.constantQueryCallback)
return AstCloner::Clone(node);
return ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
constant->cachedExpressionType = GetExpressionType(constant->value);
return constant;
}
ExpressionPtr AstOptimizer::Clone(UnaryExpression& node)

View File

@@ -122,6 +122,12 @@ namespace Nz::ShaderAst
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareConstStatement& node)
{
if (node.expression)
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareExternalStatement& /*node*/)
{
/* Nothing to do */

View File

@@ -208,6 +208,7 @@ namespace Nz::ShaderAst
}
Node(node.elseStatement);
Value(node.isConst);
}
void AstSerializerBase::Serialize(ConditionalStatement& node)
@@ -232,6 +233,14 @@ namespace Nz::ShaderAst
}
}
void AstSerializerBase::Serialize(DeclareConstStatement& node)
{
OptVal(node.constIndex);
Value(node.name);
Type(node.type);
Node(node.expression);
}
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
{
Value(node.name);

View File

@@ -577,6 +577,18 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(ConstantIndexExpression& node)
{
if (node.constantId >= m_context->constantValues.size())
throw AstError{ "invalid constant index " + std::to_string(node.constantId) };
// Replace by constant value
auto constant = ShaderBuilder::Constant(m_context->constantValues[node.constantId]);
constant->cachedExpressionType = GetExpressionType(constant->value);
return constant;
}
ExpressionPtr SanitizeVisitor::Clone(IdentifierExpression& node)
{
assert(m_context);
@@ -712,11 +724,46 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(VariableExpression& node)
{
if (node.variableId >= m_context->variableTypes.size())
throw AstError{ "invalid constant index " + std::to_string(node.variableId) };
node.cachedExpressionType = m_context->variableTypes[node.variableId];
return AstCloner::Clone(node);
}
StatementPtr SanitizeVisitor::Clone(BranchStatement& node)
{
if (node.isConst)
{
// Evaluate every condition at compilation and select the right statement
for (auto& cond : node.condStatements)
{
MandatoryExpr(cond.condition);
ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*cond.condition));
if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean })
throw AstError{ "expected a boolean value" };
if (std::get<bool>(conditionValue))
return AstCloner::Clone(*cond.statement);
}
// Every condition failed, fallback to else if any
if (node.elseStatement)
return AstCloner::Clone(*node.elseStatement);
else
return ShaderBuilder::NoOp();
}
auto clone = std::make_unique<BranchStatement>();
clone->condStatements.reserve(node.condStatements.size());
if (!m_context->currentFunction)
throw AstError{ "non-const branching statements can only exist inside a function" };
for (auto& cond : node.condStatements)
{
PushScope();
@@ -758,6 +805,31 @@ namespace Nz::ShaderAst
return ShaderBuilder::NoOp();
}
StatementPtr SanitizeVisitor::Clone(DeclareConstStatement& node)
{
auto clone = static_unique_pointer_cast<DeclareConstStatement>(AstCloner::Clone(node));
if (!clone->expression)
throw AstError{ "const variables must have an expression" };
clone->expression = Optimize(*clone->expression);
if (clone->expression->GetType() != NodeType::ConstantExpression)
throw AstError{ "const variable must have constant expressions " };
const ConstantValue& value = static_cast<ConstantExpression&>(*clone->expression).value;
ExpressionType expressionType = ResolveType(GetExpressionType(value));
if (!IsNoType(clone->type) && ResolveType(clone->type) != expressionType)
throw AstError{ "constant expression doesn't match type" };
clone->type = expressionType;
clone->constIndex = RegisterConstant(clone->name, value);
return clone;
}
StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node)
{
assert(m_context);
@@ -815,6 +887,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node)
{
if (m_context->currentFunction)
throw AstError{ "a function cannot be defined inside another function" };
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->name = node.name;
clone->parameters = node.parameters;
@@ -908,6 +983,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareOptionStatement& node)
{
if (m_context->currentFunction)
throw AstError{ "options must be declared outside of functions" };
auto clone = static_unique_pointer_cast<DeclareOptionStatement>(AstCloner::Clone(node));
clone->optType = ResolveType(clone->optType);
@@ -926,6 +1004,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node)
{
if (m_context->currentFunction)
throw AstError{ "structs must be declared outside of functions" };
auto clone = static_unique_pointer_cast<DeclareStructStatement>(AstCloner::Clone(node));
std::unordered_set<std::string> declaredMembers;
@@ -961,6 +1042,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareVariableStatement& node)
{
if (!m_context->currentFunction)
throw AstError{ "global variables outside of external blocks are forbidden" };
auto clone = static_unique_pointer_cast<DeclareVariableStatement>(AstCloner::Clone(node));
if (IsNoType(clone->varType))
{
@@ -1092,6 +1176,17 @@ namespace Nz::ShaderAst
}
ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr)
{
// Run optimizer on constant value to hopefully retrieve a single constant value
ExpressionPtr optimizedExpr = Optimize(expr);
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
throw AstError{"expected a constant expression"};
return static_cast<ConstantExpression&>(*optimizedExpr).value;
}
template<typename T>
std::unique_ptr<T> SanitizeVisitor::Optimize(T& node)
{
AstOptimizer::Options optimizerOptions;
optimizerOptions.constantQueryCallback = [this](std::size_t constantId)
@@ -1103,11 +1198,7 @@ namespace Nz::ShaderAst
optimizerOptions.enabledOptions = m_context->options.enabledOptions;
// Run optimizer on constant value to hopefully retrieve a single constant value
ExpressionPtr optimizedExpr = Optimize(expr, optimizerOptions);
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
throw AstError{"expected a constant expression"};
return static_cast<ConstantExpression&>(*optimizedExpr).value;
return static_unique_pointer_cast<T>(ShaderAst::Optimize(node, optimizerOptions));
}
std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement& funcDecl)