Shader: Add support for exported functions

This commit is contained in:
Jérôme Leclercq
2022-03-14 18:00:02 +01:00
parent 1c4ce75aa0
commit bf44672354
14 changed files with 170 additions and 67 deletions

View File

@@ -124,6 +124,7 @@ namespace Nz::ShaderAst
clone->earlyFragmentTests = Clone(node.earlyFragmentTests);
clone->entryStage = Clone(node.entryStage);
clone->funcIndex = node.funcIndex;
clone->isExported = Clone(node.isExported);
clone->name = node.name;
clone->returnType = Clone(node.returnType);

View File

@@ -14,6 +14,15 @@ namespace Nz::ShaderAst
statement.Visit(*this);
}
void AstExportVisitor::Visit(DeclareFunctionStatement& node)
{
if (!node.isExported.HasValue() || !node.isExported.GetResultingValue())
return;
if (m_callbacks->onExportedFunc)
m_callbacks->onExportedFunc(node);
}
void AstExportVisitor::Visit(DeclareStructStatement& node)
{
if (!node.isExported.HasValue() || !node.isExported.GetResultingValue())

View File

@@ -260,6 +260,7 @@ namespace Nz::ShaderAst
ExprValue(node.depthWrite);
ExprValue(node.earlyFragmentTests);
ExprValue(node.entryStage);
ExprValue(node.isExported);
OptVal(node.funcIndex);
Container(node.parameters);

View File

@@ -496,16 +496,28 @@ namespace Nz::ShaderAst
{
ExpressionPtr targetExpr = CloneExpression(MandatoryExpr(node.targetFunction));
const ExpressionType& targetExprType = GetExpressionType(*targetExpr);
const ExpressionType& resolvedType = ResolveAlias(targetExprType);
if (IsFunctionType(targetExprType))
if (IsFunctionType(resolvedType))
{
if (!m_context->currentFunction)
throw AstError{ "function calls must happen inside a function" };
if (targetExpr->GetType() != NodeType::FunctionExpression)
throw AstError{ "expected function expression" };
std::size_t targetFuncIndex;
if (targetExpr->GetType() == NodeType::FunctionExpression)
targetFuncIndex = static_cast<FunctionExpression&>(*targetExpr).funcId;
else if (targetExpr->GetType() == NodeType::AliasValueExpression)
{
const auto& alias = static_cast<AliasValueExpression&>(*targetExpr);
std::size_t targetFuncIndex = static_cast<FunctionExpression&>(*targetExpr).funcId;
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId));
if (targetIdentifier->category != IdentifierCategory::Function)
throw AstError{ "expected function expression" };
targetFuncIndex = targetIdentifier->index;
}
else
throw AstError{ "expected function expression" };
auto clone = std::make_unique<CallFunctionExpression>();
clone->targetFunction = std::move(targetExpr);
@@ -520,7 +532,7 @@ namespace Nz::ShaderAst
return clone;
}
else if (IsIntrinsicFunctionType(targetExprType))
else if (IsIntrinsicFunctionType(resolvedType))
{
if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression)
throw AstError{ "expected intrinsic function expression" };
@@ -538,9 +550,9 @@ namespace Nz::ShaderAst
return intrinsic;
}
else if (IsMethodType(targetExprType))
else if (IsMethodType(resolvedType))
{
const MethodType& methodType = std::get<MethodType>(targetExprType);
const MethodType& methodType = std::get<MethodType>(resolvedType);
std::vector<ExpressionPtr> parameters;
parameters.reserve(node.parameters.size() + 1);
@@ -987,6 +999,9 @@ namespace Nz::ShaderAst
if (node.entryStage.HasValue())
clone->entryStage = ComputeExprValue(node.entryStage);
if (node.isExported.HasValue())
clone->isExported = ComputeExprValue(node.isExported);
if (clone->entryStage.HasValue())
{
ShaderStageType stageType = clone->entryStage.GetResultingValue();
@@ -1468,7 +1483,7 @@ namespace Nz::ShaderAst
hasher.Append(moduleUuidBytes.data(), moduleUuidBytes.size());
hasher.End();
std::string identifier = "__" + hasher.End().ToHex().substr(0, 8);
std::string identifier = "_" + hasher.End().ToHex().substr(0, 8);
// Load new module
auto moduleEnvironment = std::make_shared<Environment>();
@@ -1519,6 +1534,19 @@ namespace Nz::ShaderAst
std::vector<DeclareAliasStatementPtr> aliasStatements;
AstExportVisitor::Callbacks callbacks;
callbacks.onExportedFunc = [&](DeclareFunctionStatement& node)
{
assert(node.funcIndex);
moduleData.dependenciesVisitor->MarkFunctionAsUsed(*node.funcIndex);
if (!exportedSet.usedFunctions.UnboundedTest(*node.funcIndex))
{
exportedSet.usedFunctions.UnboundedSet(*node.funcIndex);
aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.name, ShaderBuilder::Function(*node.funcIndex)));
}
};
callbacks.onExportedStruct = [&](DeclareStructStatement& node)
{
assert(node.structIndex);
@@ -2475,13 +2503,18 @@ namespace Nz::ShaderAst
std::size_t structIndex = ResolveStruct(exprType);
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
}
else if (IsFunctionType(exprType))
{
std::size_t funcIndex = std::get<FunctionType>(exprType).funcIndex;
node.aliasIndex = RegisterAlias(node.name, { funcIndex, IdentifierCategory::Function }, node.aliasIndex);
}
else if (IsAliasType(exprType))
{
const AliasType& alias = std::get<AliasType>(exprType);
node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex);
}
else
throw AstError{ "for now, only structs can be aliased" };
throw AstError{ "for now, only aliases, functions and structs can be aliased" };
}
void SanitizeVisitor::Validate(WhileStatement& node)
@@ -2659,10 +2692,22 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(CallFunctionExpression& node)
{
const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
assert(std::holds_alternative<FunctionType>(targetFuncType));
std::size_t targetFuncIndex;
if (node.targetFunction->GetType() == NodeType::FunctionExpression)
targetFuncIndex = static_cast<FunctionExpression&>(*node.targetFunction).funcId;
else if (node.targetFunction->GetType() == NodeType::AliasValueExpression)
{
const auto& alias = static_cast<AliasValueExpression&>(*node.targetFunction);
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId));
if (targetIdentifier->category != IdentifierCategory::Function)
throw AstError{ "expected function expression" };
targetFuncIndex = targetIdentifier->index;
}
else
throw AstError{ "expected function expression" };
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
auto& funcData = m_context->functions.Retrieve(targetFuncIndex);
const DeclareFunctionStatement* referenceDeclaration = funcData.node;