|
|
|
|
@@ -7,6 +7,7 @@
|
|
|
|
|
#include <Nazara/Core/StackArray.hpp>
|
|
|
|
|
#include <Nazara/Shader/ShaderBuilder.hpp>
|
|
|
|
|
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
|
|
|
|
|
#include <Nazara/Shader/Ast/AstOptimizer.hpp>
|
|
|
|
|
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
|
|
|
|
#include <stdexcept>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
@@ -30,7 +31,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
struct SanitizeVisitor::Context
|
|
|
|
|
{
|
|
|
|
|
struct FunctionData
|
|
|
|
|
struct CurrentFunctionData
|
|
|
|
|
{
|
|
|
|
|
std::optional<ShaderStageType> stageType;
|
|
|
|
|
Bitset<> calledFunctions;
|
|
|
|
|
@@ -38,11 +39,19 @@ namespace Nz::ShaderAst
|
|
|
|
|
FunctionFlags flags;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::size_t nextOptionIndex = 0;
|
|
|
|
|
Options options;
|
|
|
|
|
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
|
|
|
|
|
std::unordered_set<std::string> declaredExternalVar;
|
|
|
|
|
std::unordered_set<UInt64> usedBindingIndexes;
|
|
|
|
|
FunctionData* currentFunction = nullptr;
|
|
|
|
|
std::vector<Identifier> identifiersInScope;
|
|
|
|
|
std::vector<ConstantValue> constantValues;
|
|
|
|
|
std::vector<FunctionData> functions;
|
|
|
|
|
std::vector<IntrinsicType> intrinsics;
|
|
|
|
|
std::vector<StructDescription*> structs;
|
|
|
|
|
std::vector<ExpressionType> variableTypes;
|
|
|
|
|
std::vector<std::size_t> scopeSizes;
|
|
|
|
|
CurrentFunctionData* currentFunction = nullptr;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Sanitize(Statement& statement, const Options& options, std::string* error)
|
|
|
|
|
@@ -123,14 +132,23 @@ namespace Nz::ShaderAst
|
|
|
|
|
accessIndexPtr = static_cast<AccessIndexExpression*>(indexedExpr.get());
|
|
|
|
|
|
|
|
|
|
std::size_t structIndex = ResolveStruct(exprType);
|
|
|
|
|
assert(structIndex < m_structs.size());
|
|
|
|
|
const StructDescription& s = m_structs[structIndex];
|
|
|
|
|
assert(structIndex < m_context->structs.size());
|
|
|
|
|
const StructDescription* s = m_context->structs[structIndex];
|
|
|
|
|
|
|
|
|
|
auto it = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == identifier; });
|
|
|
|
|
if (it == s.members.end())
|
|
|
|
|
auto it = std::find_if(s->members.begin(), s->members.end(), [&](const auto& field)
|
|
|
|
|
{
|
|
|
|
|
if (field.name != identifier)
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
if (field.cond.HasValue() && !field.cond.GetResultingValue())
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
});
|
|
|
|
|
if (it == s->members.end())
|
|
|
|
|
throw AstError{ "unknown field " + identifier };
|
|
|
|
|
|
|
|
|
|
accessIndexPtr->indices.push_back(ShaderBuilder::Constant(Int32(std::distance(s.members.begin(), it))));
|
|
|
|
|
accessIndexPtr->indices.push_back(ShaderBuilder::Constant(Int32(std::distance(s->members.begin(), it))));
|
|
|
|
|
accessIndexPtr->cachedExpressionType = ResolveType(it->type);
|
|
|
|
|
}
|
|
|
|
|
else if (IsVectorType(exprType))
|
|
|
|
|
@@ -419,7 +437,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
for (const auto& param : node.parameters)
|
|
|
|
|
parameters.push_back(CloneExpression(param));
|
|
|
|
|
|
|
|
|
|
auto intrinsic = ShaderBuilder::Intrinsic(m_intrinsics[identifier->index], std::move(parameters));
|
|
|
|
|
auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics[identifier->index], std::move(parameters));
|
|
|
|
|
Validate(*intrinsic);
|
|
|
|
|
|
|
|
|
|
return intrinsic;
|
|
|
|
|
@@ -437,11 +455,11 @@ namespace Nz::ShaderAst
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
// Identifier not found, maybe the function is declared later
|
|
|
|
|
auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node->name == functionName; });
|
|
|
|
|
if (it == m_functions.end())
|
|
|
|
|
auto it = std::find_if(m_context->functions.begin(), m_context->functions.end(), [&](const auto& funcData) { return funcData.node->name == functionName; });
|
|
|
|
|
if (it == m_context->functions.end())
|
|
|
|
|
throw AstError{ "function " + functionName + " does not exist" };
|
|
|
|
|
|
|
|
|
|
targetFuncIndex = std::distance(m_functions.begin(), it);
|
|
|
|
|
targetFuncIndex = std::distance(m_context->functions.begin(), it);
|
|
|
|
|
|
|
|
|
|
clone->targetFunction = targetFuncIndex;
|
|
|
|
|
}
|
|
|
|
|
@@ -451,7 +469,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex);
|
|
|
|
|
|
|
|
|
|
Validate(*clone, m_functions[targetFuncIndex].node);
|
|
|
|
|
Validate(*clone, m_context->functions[targetFuncIndex].node);
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
@@ -507,18 +525,18 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(ConditionalExpression& node)
|
|
|
|
|
{
|
|
|
|
|
MandatoryExpr(node.condition);
|
|
|
|
|
MandatoryExpr(node.truePath);
|
|
|
|
|
MandatoryExpr(node.falsePath);
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<ConditionalExpression>(AstCloner::Clone(node));
|
|
|
|
|
ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*node.condition));
|
|
|
|
|
if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean })
|
|
|
|
|
throw AstError{ "expected a boolean value" };
|
|
|
|
|
|
|
|
|
|
const ExpressionType& leftExprType = GetExpressionType(*clone->truePath);
|
|
|
|
|
if (leftExprType != GetExpressionType(*clone->falsePath))
|
|
|
|
|
throw AstError{ "true path type must match false path type" };
|
|
|
|
|
|
|
|
|
|
clone->cachedExpressionType = leftExprType;
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
if (std::get<bool>(conditionValue))
|
|
|
|
|
return AstCloner::Clone(*node.truePath);
|
|
|
|
|
else
|
|
|
|
|
return AstCloner::Clone(*node.falsePath);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
|
|
|
|
|
@@ -537,15 +555,31 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (!identifier)
|
|
|
|
|
throw AstError{ "unknown identifier " + node.identifier };
|
|
|
|
|
|
|
|
|
|
if (identifier->type != Identifier::Type::Variable)
|
|
|
|
|
throw AstError{ "expected variable identifier" };
|
|
|
|
|
switch (identifier->type)
|
|
|
|
|
{
|
|
|
|
|
case Identifier::Type::Constant:
|
|
|
|
|
{
|
|
|
|
|
// Replace IdentifierExpression by ConstantIndexExpression
|
|
|
|
|
auto constantExpr = std::make_unique<ConstantIndexExpression>();
|
|
|
|
|
constantExpr->cachedExpressionType = GetExpressionType(m_context->constantValues[identifier->index]);
|
|
|
|
|
constantExpr->constantId = identifier->index;
|
|
|
|
|
|
|
|
|
|
// Replace IdentifierExpression by VariableExpression
|
|
|
|
|
auto varExpr = std::make_unique<VariableExpression>();
|
|
|
|
|
varExpr->cachedExpressionType = m_variableTypes[identifier->index];
|
|
|
|
|
varExpr->variableId = identifier->index;
|
|
|
|
|
return constantExpr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return varExpr;
|
|
|
|
|
case Identifier::Type::Variable:
|
|
|
|
|
{
|
|
|
|
|
// Replace IdentifierExpression by VariableExpression
|
|
|
|
|
auto varExpr = std::make_unique<VariableExpression>();
|
|
|
|
|
varExpr->cachedExpressionType = m_context->variableTypes[identifier->index];
|
|
|
|
|
varExpr->variableId = identifier->index;
|
|
|
|
|
|
|
|
|
|
return varExpr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
throw AstError{ "expected constant or variable identifier" };
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node)
|
|
|
|
|
@@ -561,26 +595,20 @@ namespace Nz::ShaderAst
|
|
|
|
|
MandatoryExpr(node.truePath);
|
|
|
|
|
MandatoryExpr(node.falsePath);
|
|
|
|
|
|
|
|
|
|
auto condExpr = std::make_unique<ConditionalExpression>();
|
|
|
|
|
condExpr->truePath = CloneExpression(node.truePath);
|
|
|
|
|
condExpr->falsePath = CloneExpression(node.falsePath);
|
|
|
|
|
|
|
|
|
|
const Identifier* identifier = FindIdentifier(node.optionName);
|
|
|
|
|
if (!identifier)
|
|
|
|
|
throw AstError{ "unknown option " + node.optionName };
|
|
|
|
|
throw AstError{ "unknown constant " + node.optionName };
|
|
|
|
|
|
|
|
|
|
if (identifier->type != Identifier::Type::Option)
|
|
|
|
|
throw AstError{ "expected option identifier" };
|
|
|
|
|
if (identifier->type != Identifier::Type::Constant)
|
|
|
|
|
throw AstError{ "expected constant identifier" };
|
|
|
|
|
|
|
|
|
|
condExpr->optionIndex = identifier->index;
|
|
|
|
|
if (GetExpressionType(m_context->constantValues[identifier->index]) != ExpressionType{ PrimitiveType::Boolean })
|
|
|
|
|
throw AstError{ "constant is not a boolean" };
|
|
|
|
|
|
|
|
|
|
const ExpressionType& leftExprType = GetExpressionType(*condExpr->truePath);
|
|
|
|
|
if (leftExprType != GetExpressionType(*condExpr->falsePath))
|
|
|
|
|
throw AstError{ "true path type must match false path type" };
|
|
|
|
|
|
|
|
|
|
condExpr->cachedExpressionType = leftExprType;
|
|
|
|
|
|
|
|
|
|
return condExpr;
|
|
|
|
|
if (std::get<bool>(m_context->constantValues[identifier->index]))
|
|
|
|
|
return CloneExpression(node.truePath);
|
|
|
|
|
else
|
|
|
|
|
return CloneExpression(node.falsePath);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
|
|
|
|
|
@@ -687,43 +715,54 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(ConditionalStatement& node)
|
|
|
|
|
{
|
|
|
|
|
MandatoryExpr(node.condition);
|
|
|
|
|
MandatoryStatement(node.statement);
|
|
|
|
|
|
|
|
|
|
PushScope();
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<ConditionalStatement>(AstCloner::Clone(node));
|
|
|
|
|
ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*node.condition));
|
|
|
|
|
if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean })
|
|
|
|
|
throw AstError{ "expected a boolean value" };
|
|
|
|
|
|
|
|
|
|
PopScope();
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
if (std::get<bool>(conditionValue))
|
|
|
|
|
return AstCloner::Clone(*node.statement);
|
|
|
|
|
else
|
|
|
|
|
return ShaderBuilder::NoOp();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node)
|
|
|
|
|
{
|
|
|
|
|
assert(m_context);
|
|
|
|
|
|
|
|
|
|
for (const auto& extVar : node.externalVars)
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareExternalStatement>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
UInt32 defaultBlockSet = 0;
|
|
|
|
|
if (clone->bindingSet.HasValue())
|
|
|
|
|
defaultBlockSet = ComputeAttributeValue(clone->bindingSet);
|
|
|
|
|
|
|
|
|
|
for (auto& extVar : clone->externalVars)
|
|
|
|
|
{
|
|
|
|
|
if (!extVar.bindingIndex)
|
|
|
|
|
if (!extVar.bindingIndex.HasValue())
|
|
|
|
|
throw AstError{ "external variable " + extVar.name + " requires a binding index" };
|
|
|
|
|
|
|
|
|
|
UInt64 bindingIndex = *extVar.bindingIndex;
|
|
|
|
|
UInt64 bindingSet = extVar.bindingSet.value_or(0);
|
|
|
|
|
if (extVar.bindingSet.HasValue())
|
|
|
|
|
ComputeAttributeValue(extVar.bindingSet);
|
|
|
|
|
else
|
|
|
|
|
extVar.bindingSet = defaultBlockSet;
|
|
|
|
|
|
|
|
|
|
UInt64 bindingSet = extVar.bindingSet.GetResultingValue();
|
|
|
|
|
|
|
|
|
|
UInt64 bindingIndex = ComputeAttributeValue(extVar.bindingIndex);
|
|
|
|
|
|
|
|
|
|
UInt64 bindingKey = bindingSet << 32 | bindingIndex;
|
|
|
|
|
if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end())
|
|
|
|
|
throw AstError{ "Binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" };
|
|
|
|
|
throw AstError{ "binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" };
|
|
|
|
|
|
|
|
|
|
m_context->usedBindingIndexes.insert(bindingKey);
|
|
|
|
|
|
|
|
|
|
if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end())
|
|
|
|
|
throw AstError{ "External variable " + extVar.name + " is already declared" };
|
|
|
|
|
throw AstError{ "external variable " + extVar.name + " is already declared" };
|
|
|
|
|
|
|
|
|
|
m_context->declaredExternalVar.insert(extVar.name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareExternalStatement>(AstCloner::Clone(node));
|
|
|
|
|
for (auto& extVar : clone->externalVars)
|
|
|
|
|
{
|
|
|
|
|
extVar.type = ResolveType(extVar.type);
|
|
|
|
|
|
|
|
|
|
ExpressionType varType;
|
|
|
|
|
@@ -746,9 +785,23 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node)
|
|
|
|
|
{
|
|
|
|
|
if (node.entryStage)
|
|
|
|
|
auto clone = std::make_unique<DeclareFunctionStatement>();
|
|
|
|
|
clone->name = node.name;
|
|
|
|
|
clone->parameters = node.parameters;
|
|
|
|
|
clone->returnType = ResolveType(node.returnType);
|
|
|
|
|
|
|
|
|
|
if (node.depthWrite.HasValue())
|
|
|
|
|
clone->depthWrite = ComputeAttributeValue(node.depthWrite);
|
|
|
|
|
|
|
|
|
|
if (node.earlyFragmentTests.HasValue())
|
|
|
|
|
clone->earlyFragmentTests = ComputeAttributeValue(node.earlyFragmentTests);
|
|
|
|
|
|
|
|
|
|
if (node.entryStage.HasValue())
|
|
|
|
|
clone->entryStage = ComputeAttributeValue(node.entryStage);
|
|
|
|
|
|
|
|
|
|
if (clone->entryStage.HasValue())
|
|
|
|
|
{
|
|
|
|
|
ShaderStageType stageType = *node.entryStage;
|
|
|
|
|
ShaderStageType stageType = clone->entryStage.GetResultingValue();
|
|
|
|
|
|
|
|
|
|
if (m_context->entryFunctions[UnderlyingCast(stageType)])
|
|
|
|
|
throw AstError{ "the same entry type has been defined multiple times" };
|
|
|
|
|
@@ -760,26 +813,17 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
if (stageType != ShaderStageType::Fragment)
|
|
|
|
|
{
|
|
|
|
|
if (node.depthWrite.has_value())
|
|
|
|
|
if (node.depthWrite.HasValue())
|
|
|
|
|
throw AstError{ "only fragment entry-points can have the depth_write attribute" };
|
|
|
|
|
|
|
|
|
|
if (node.earlyFragmentTests.has_value())
|
|
|
|
|
if (node.earlyFragmentTests.HasValue())
|
|
|
|
|
throw AstError{ "only functions with entry(frag) attribute can have the early_fragments_tests attribute" };
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto clone = std::make_unique<DeclareFunctionStatement>();
|
|
|
|
|
clone->depthWrite = node.depthWrite;
|
|
|
|
|
clone->earlyFragmentTests = node.earlyFragmentTests;
|
|
|
|
|
clone->entryStage = node.entryStage;
|
|
|
|
|
clone->name = node.name;
|
|
|
|
|
clone->optionName = node.optionName;
|
|
|
|
|
clone->parameters = node.parameters;
|
|
|
|
|
clone->returnType = ResolveType(node.returnType);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Context::FunctionData tempFuncData;
|
|
|
|
|
tempFuncData.stageType = node.entryStage;
|
|
|
|
|
Context::CurrentFunctionData tempFuncData;
|
|
|
|
|
if (node.entryStage.HasValue())
|
|
|
|
|
tempFuncData.stageType = node.entryStage.GetResultingValue();
|
|
|
|
|
|
|
|
|
|
m_context->currentFunction = &tempFuncData;
|
|
|
|
|
|
|
|
|
|
@@ -803,31 +847,17 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
m_context->currentFunction = nullptr;
|
|
|
|
|
|
|
|
|
|
if (clone->earlyFragmentTests.has_value() && *clone->earlyFragmentTests)
|
|
|
|
|
if (clone->earlyFragmentTests.HasValue() && clone->earlyFragmentTests.GetResultingValue())
|
|
|
|
|
{
|
|
|
|
|
//TODO: warning and disable early fragment tests
|
|
|
|
|
throw AstError{ "discard is not compatible with early fragment tests" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!clone->optionName.empty())
|
|
|
|
|
{
|
|
|
|
|
const Identifier* identifier = FindIdentifier(node.optionName);
|
|
|
|
|
if (!identifier)
|
|
|
|
|
throw AstError{ "unknown option " + node.optionName };
|
|
|
|
|
|
|
|
|
|
if (identifier->type != Identifier::Type::Option)
|
|
|
|
|
throw AstError{ "expected option identifier" };
|
|
|
|
|
|
|
|
|
|
std::size_t optionIndex = identifier->index;
|
|
|
|
|
|
|
|
|
|
return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node == &node; });
|
|
|
|
|
assert(it != m_functions.end());
|
|
|
|
|
auto it = std::find_if(m_context->functions.begin(), m_context->functions.end(), [&](const auto& funcData) { return funcData.node == &node; });
|
|
|
|
|
assert(it != m_context->functions.end());
|
|
|
|
|
assert(!it->defined);
|
|
|
|
|
|
|
|
|
|
std::size_t funcIndex = std::distance(m_functions.begin(), it);
|
|
|
|
|
std::size_t funcIndex = std::distance(m_context->functions.begin(), it);
|
|
|
|
|
|
|
|
|
|
clone->funcIndex = funcIndex;
|
|
|
|
|
|
|
|
|
|
@@ -836,8 +866,8 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i))
|
|
|
|
|
{
|
|
|
|
|
assert(i < m_functions.size());
|
|
|
|
|
auto& targetFunc = m_functions[i];
|
|
|
|
|
assert(i < m_context->functions.size());
|
|
|
|
|
auto& targetFunc = m_context->functions[i];
|
|
|
|
|
targetFunc.calledByFunctions.UnboundedSet(funcIndex);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -854,7 +884,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (clone->initialValue && clone->optType != GetExpressionType(*clone->initialValue))
|
|
|
|
|
throw AstError{ "option " + clone->optName + " initial expression must be of the same type than the option" };
|
|
|
|
|
|
|
|
|
|
clone->optIndex = RegisterOption(clone->optName, clone->optType);
|
|
|
|
|
std::size_t optionIndex = m_context->nextOptionIndex++;
|
|
|
|
|
|
|
|
|
|
clone->optIndex = RegisterConstant(clone->optName, TestBit(m_context->options.enabledOptions, optionIndex));
|
|
|
|
|
|
|
|
|
|
if (m_context->options.removeOptionDeclaration)
|
|
|
|
|
return ShaderBuilder::NoOp();
|
|
|
|
|
@@ -864,22 +896,33 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node)
|
|
|
|
|
{
|
|
|
|
|
std::unordered_set<std::string> declaredMembers;
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareStructStatement>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
for (auto& member : node.description.members)
|
|
|
|
|
std::unordered_set<std::string> declaredMembers;
|
|
|
|
|
for (auto& member : clone->description.members)
|
|
|
|
|
{
|
|
|
|
|
if (member.cond.HasValue())
|
|
|
|
|
{
|
|
|
|
|
member.cond = ComputeAttributeValue(member.cond);
|
|
|
|
|
if (!member.cond.GetResultingValue())
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (member.builtin.HasValue())
|
|
|
|
|
member.builtin = ComputeAttributeValue(member.builtin);
|
|
|
|
|
|
|
|
|
|
if (member.locationIndex.HasValue())
|
|
|
|
|
member.locationIndex = ComputeAttributeValue(member.locationIndex);
|
|
|
|
|
|
|
|
|
|
if (declaredMembers.find(member.name) != declaredMembers.end())
|
|
|
|
|
throw AstError{ "struct member " + member.name + " found multiple time" };
|
|
|
|
|
|
|
|
|
|
declaredMembers.insert(member.name);
|
|
|
|
|
|
|
|
|
|
member.type = ResolveType(member.type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto clone = static_unique_pointer_cast<DeclareStructStatement>(AstCloner::Clone(node));
|
|
|
|
|
|
|
|
|
|
for (auto& member : clone->description.members)
|
|
|
|
|
member.type = ResolveType(member.type);
|
|
|
|
|
|
|
|
|
|
clone->structIndex = RegisterStruct(clone->description.name, clone->description);
|
|
|
|
|
clone->structIndex = RegisterStruct(clone->description.name, &clone->description);
|
|
|
|
|
|
|
|
|
|
SanitizeIdentifier(clone->description.name);
|
|
|
|
|
|
|
|
|
|
@@ -951,6 +994,15 @@ namespace Nz::ShaderAst
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier*
|
|
|
|
|
{
|
|
|
|
|
auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });
|
|
|
|
|
if (it == m_context->identifiersInScope.rend())
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
|
|
return &*it;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Expression& SanitizeVisitor::MandatoryExpr(ExpressionPtr& node)
|
|
|
|
|
{
|
|
|
|
|
if (!node)
|
|
|
|
|
@@ -969,20 +1021,69 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::PushScope()
|
|
|
|
|
{
|
|
|
|
|
m_scopeSizes.push_back(m_identifiersInScope.size());
|
|
|
|
|
m_context->scopeSizes.push_back(m_context->identifiersInScope.size());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::PopScope()
|
|
|
|
|
{
|
|
|
|
|
assert(!m_scopeSizes.empty());
|
|
|
|
|
m_identifiersInScope.resize(m_scopeSizes.back());
|
|
|
|
|
m_scopeSizes.pop_back();
|
|
|
|
|
assert(!m_context->scopeSizes.empty());
|
|
|
|
|
m_context->identifiersInScope.resize(m_context->scopeSizes.back());
|
|
|
|
|
m_context->scopeSizes.pop_back();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename T>
|
|
|
|
|
const T& SanitizeVisitor::ComputeAttributeValue(AttributeValue<T>& attribute)
|
|
|
|
|
{
|
|
|
|
|
if (!attribute.HasValue())
|
|
|
|
|
throw AstError{"attribute expected a value"};
|
|
|
|
|
|
|
|
|
|
if (attribute.IsExpression())
|
|
|
|
|
{
|
|
|
|
|
ConstantValue value = ComputeConstantValue(*attribute.GetExpression());
|
|
|
|
|
if constexpr (TypeListFind<ConstantTypes, T>)
|
|
|
|
|
{
|
|
|
|
|
if (!std::holds_alternative<T>(value))
|
|
|
|
|
{
|
|
|
|
|
// HAAAAAX
|
|
|
|
|
if (std::holds_alternative<Int32>(value) && std::is_same_v<T, UInt32>)
|
|
|
|
|
attribute = static_cast<UInt32>(std::get<Int32>(value));
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "unexpected attribute type" };
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
attribute = std::get<T>(value);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "unexpected expression for this type" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert(attribute.IsResultingValue());
|
|
|
|
|
return attribute.GetResultingValue();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr)
|
|
|
|
|
{
|
|
|
|
|
AstOptimizer::Options optimizerOptions;
|
|
|
|
|
optimizerOptions.constantQueryCallback = [this](std::size_t constantId)
|
|
|
|
|
{
|
|
|
|
|
assert(constantId < m_context->constantValues.size());
|
|
|
|
|
return m_context->constantValues[constantId];
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
optimizerOptions.enabledOptions = m_context->options.enabledOptions;
|
|
|
|
|
|
|
|
|
|
// Run optimizer on constant value to hopefully retrieve a single constant value
|
|
|
|
|
ExpressionPtr optimizedExpr = Optimize(expr, optimizerOptions);
|
|
|
|
|
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
|
|
|
|
|
throw AstError{"expected a constant expression"};
|
|
|
|
|
|
|
|
|
|
return static_cast<ConstantExpression&>(*optimizedExpr).value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement& funcDecl)
|
|
|
|
|
{
|
|
|
|
|
std::size_t functionIndex = m_functions.size();
|
|
|
|
|
auto& funcData = m_functions.emplace_back();
|
|
|
|
|
std::size_t functionIndex = m_context->functions.size();
|
|
|
|
|
auto& funcData = m_context->functions.emplace_back();
|
|
|
|
|
funcData.node = &funcDecl;
|
|
|
|
|
|
|
|
|
|
return functionIndex;
|
|
|
|
|
@@ -990,8 +1091,8 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
|
|
|
|
|
{
|
|
|
|
|
assert(funcIndex < m_functions.size());
|
|
|
|
|
auto& funcData = m_functions[funcIndex];
|
|
|
|
|
assert(funcIndex < m_context->functions.size());
|
|
|
|
|
auto& funcData = m_context->functions[funcIndex];
|
|
|
|
|
assert(funcData.defined);
|
|
|
|
|
|
|
|
|
|
funcData.flags |= flags;
|
|
|
|
|
@@ -999,11 +1100,28 @@ namespace Nz::ShaderAst
|
|
|
|
|
for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i))
|
|
|
|
|
PropagateFunctionFlags(i, funcData.flags, seen);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value)
|
|
|
|
|
{
|
|
|
|
|
if (FindIdentifier(name))
|
|
|
|
|
throw AstError{ name + " is already used" };
|
|
|
|
|
|
|
|
|
|
std::size_t constantIndex = m_context->constantValues.size();
|
|
|
|
|
m_context->constantValues.emplace_back(std::move(value));
|
|
|
|
|
|
|
|
|
|
m_context->identifiersInScope.push_back({
|
|
|
|
|
std::move(name),
|
|
|
|
|
constantIndex,
|
|
|
|
|
Identifier::Type::Constant
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
return constantIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto SanitizeVisitor::RegisterFunction(std::size_t functionIndex) -> FunctionData&
|
|
|
|
|
{
|
|
|
|
|
assert(m_functions.size() >= functionIndex);
|
|
|
|
|
auto& funcData = m_functions[functionIndex];
|
|
|
|
|
assert(m_context->functions.size() >= functionIndex);
|
|
|
|
|
auto& funcData = m_context->functions[functionIndex];
|
|
|
|
|
assert(!funcData.defined);
|
|
|
|
|
funcData.defined = true;
|
|
|
|
|
|
|
|
|
|
@@ -1012,10 +1130,10 @@ namespace Nz::ShaderAst
|
|
|
|
|
bool duplicate = true;
|
|
|
|
|
|
|
|
|
|
// Functions cannot be declared twice, except for entry ones if their stages are different
|
|
|
|
|
if (funcData.node->entryStage && identifier->type == Identifier::Type::Function)
|
|
|
|
|
if (funcData.node->entryStage.HasValue() && identifier->type == Identifier::Type::Function)
|
|
|
|
|
{
|
|
|
|
|
auto& otherFunction = m_functions[identifier->index];
|
|
|
|
|
if (funcData.node->entryStage != otherFunction.node->entryStage)
|
|
|
|
|
auto& otherFunction = m_context->functions[identifier->index];
|
|
|
|
|
if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue())
|
|
|
|
|
duplicate = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1023,7 +1141,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
throw AstError{ funcData.node->name + " is already used" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
m_context->identifiersInScope.push_back({
|
|
|
|
|
funcData.node->name,
|
|
|
|
|
functionIndex,
|
|
|
|
|
Identifier::Type::Function
|
|
|
|
|
@@ -1037,10 +1155,10 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (FindIdentifier(name))
|
|
|
|
|
throw AstError{ name + " is already used" };
|
|
|
|
|
|
|
|
|
|
std::size_t intrinsicIndex = m_intrinsics.size();
|
|
|
|
|
m_intrinsics.push_back(type);
|
|
|
|
|
std::size_t intrinsicIndex = m_context->intrinsics.size();
|
|
|
|
|
m_context->intrinsics.push_back(type);
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
m_context->identifiersInScope.push_back({
|
|
|
|
|
std::move(name),
|
|
|
|
|
intrinsicIndex,
|
|
|
|
|
Identifier::Type::Intrinsic
|
|
|
|
|
@@ -1049,32 +1167,15 @@ namespace Nz::ShaderAst
|
|
|
|
|
return intrinsicIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::RegisterOption(std::string name, ExpressionType type)
|
|
|
|
|
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description)
|
|
|
|
|
{
|
|
|
|
|
if (FindIdentifier(name))
|
|
|
|
|
throw AstError{ name + " is already used" };
|
|
|
|
|
|
|
|
|
|
std::size_t optionIndex = m_options.size();
|
|
|
|
|
m_options.emplace_back(std::move(type));
|
|
|
|
|
std::size_t structIndex = m_context->structs.size();
|
|
|
|
|
m_context->structs.emplace_back(description);
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
std::move(name),
|
|
|
|
|
optionIndex,
|
|
|
|
|
Identifier::Type::Option
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
return optionIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription description)
|
|
|
|
|
{
|
|
|
|
|
if (FindIdentifier(name))
|
|
|
|
|
throw AstError{ name + " is already used" };
|
|
|
|
|
|
|
|
|
|
std::size_t structIndex = m_structs.size();
|
|
|
|
|
m_structs.emplace_back(std::move(description));
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
m_context->identifiersInScope.push_back({
|
|
|
|
|
std::move(name),
|
|
|
|
|
structIndex,
|
|
|
|
|
Identifier::Type::Struct
|
|
|
|
|
@@ -1089,10 +1190,10 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (auto* identifier = FindIdentifier(name); identifier && identifier->type != Identifier::Type::Variable)
|
|
|
|
|
throw AstError{ name + " is already used" };
|
|
|
|
|
|
|
|
|
|
std::size_t varIndex = m_variableTypes.size();
|
|
|
|
|
m_variableTypes.emplace_back(std::move(type));
|
|
|
|
|
std::size_t varIndex = m_context->variableTypes.size();
|
|
|
|
|
m_context->variableTypes.emplace_back(std::move(type));
|
|
|
|
|
|
|
|
|
|
m_identifiersInScope.push_back({
|
|
|
|
|
m_context->identifiersInScope.push_back({
|
|
|
|
|
std::move(name),
|
|
|
|
|
varIndex,
|
|
|
|
|
Identifier::Type::Variable
|
|
|
|
|
@@ -1106,17 +1207,17 @@ namespace Nz::ShaderAst
|
|
|
|
|
// Once every function is known, we can propagate flags
|
|
|
|
|
|
|
|
|
|
Bitset<> seen;
|
|
|
|
|
for (std::size_t funcIndex = 0; funcIndex < m_functions.size(); ++funcIndex)
|
|
|
|
|
for (std::size_t funcIndex = 0; funcIndex < m_context->functions.size(); ++funcIndex)
|
|
|
|
|
{
|
|
|
|
|
auto& funcData = m_functions[funcIndex];
|
|
|
|
|
auto& funcData = m_context->functions[funcIndex];
|
|
|
|
|
|
|
|
|
|
PropagateFunctionFlags(funcIndex, funcData.flags, seen);
|
|
|
|
|
seen.Clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const FunctionData& funcData : m_functions)
|
|
|
|
|
for (const FunctionData& funcData : m_context->functions)
|
|
|
|
|
{
|
|
|
|
|
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage && *funcData.node->entryStage != ShaderStageType::Fragment)
|
|
|
|
|
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
|
|
|
|
|
throw AstError{ "discard can only be used in the fragment stage" };
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -1246,18 +1347,18 @@ namespace Nz::ShaderAst
|
|
|
|
|
auto& indexExpr = node.indices[i];
|
|
|
|
|
|
|
|
|
|
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
|
|
|
|
|
if (indexExpr->GetType() != NodeType::ConstantExpression)
|
|
|
|
|
throw AstError{ "struct can only be accessed with constant indices" };
|
|
|
|
|
if (indexExpr->GetType() != NodeType::ConstantExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
|
|
|
|
|
throw AstError{ "struct can only be accessed with constant i32 indices" };
|
|
|
|
|
|
|
|
|
|
ConstantExpression& constantExpr = static_cast<ConstantExpression&>(*indexExpr);
|
|
|
|
|
|
|
|
|
|
Int32 index = std::get<Int32>(constantExpr.value);
|
|
|
|
|
|
|
|
|
|
std::size_t structIndex = ResolveStruct(exprType);
|
|
|
|
|
assert(structIndex < m_structs.size());
|
|
|
|
|
const StructDescription& s = m_structs[structIndex];
|
|
|
|
|
assert(structIndex < m_context->structs.size());
|
|
|
|
|
const StructDescription* s = m_context->structs[structIndex];
|
|
|
|
|
|
|
|
|
|
exprType = ResolveType(s.members[index].type);
|
|
|
|
|
exprType = ResolveType(s->members[index].type);
|
|
|
|
|
}
|
|
|
|
|
else if (IsMatrixType(exprType))
|
|
|
|
|
{
|
|
|
|
|
@@ -1283,7 +1384,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration)
|
|
|
|
|
{
|
|
|
|
|
if (referenceDeclaration->entryStage)
|
|
|
|
|
if (referenceDeclaration->entryStage.HasValue())
|
|
|
|
|
throw AstError{ referenceDeclaration->name + " is an entry function which cannot be called by the program" };
|
|
|
|
|
|
|
|
|
|
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
|
|
|
|
|