Shader: Add support for exported functions
This commit is contained in:
@@ -124,6 +124,7 @@ namespace Nz::ShaderAst
|
||||
clone->earlyFragmentTests = Clone(node.earlyFragmentTests);
|
||||
clone->entryStage = Clone(node.entryStage);
|
||||
clone->funcIndex = node.funcIndex;
|
||||
clone->isExported = Clone(node.isExported);
|
||||
clone->name = node.name;
|
||||
clone->returnType = Clone(node.returnType);
|
||||
|
||||
|
||||
@@ -14,6 +14,15 @@ namespace Nz::ShaderAst
|
||||
statement.Visit(*this);
|
||||
}
|
||||
|
||||
void AstExportVisitor::Visit(DeclareFunctionStatement& node)
|
||||
{
|
||||
if (!node.isExported.HasValue() || !node.isExported.GetResultingValue())
|
||||
return;
|
||||
|
||||
if (m_callbacks->onExportedFunc)
|
||||
m_callbacks->onExportedFunc(node);
|
||||
}
|
||||
|
||||
void AstExportVisitor::Visit(DeclareStructStatement& node)
|
||||
{
|
||||
if (!node.isExported.HasValue() || !node.isExported.GetResultingValue())
|
||||
|
||||
@@ -260,6 +260,7 @@ namespace Nz::ShaderAst
|
||||
ExprValue(node.depthWrite);
|
||||
ExprValue(node.earlyFragmentTests);
|
||||
ExprValue(node.entryStage);
|
||||
ExprValue(node.isExported);
|
||||
OptVal(node.funcIndex);
|
||||
|
||||
Container(node.parameters);
|
||||
|
||||
@@ -496,16 +496,28 @@ namespace Nz::ShaderAst
|
||||
{
|
||||
ExpressionPtr targetExpr = CloneExpression(MandatoryExpr(node.targetFunction));
|
||||
const ExpressionType& targetExprType = GetExpressionType(*targetExpr);
|
||||
const ExpressionType& resolvedType = ResolveAlias(targetExprType);
|
||||
|
||||
if (IsFunctionType(targetExprType))
|
||||
if (IsFunctionType(resolvedType))
|
||||
{
|
||||
if (!m_context->currentFunction)
|
||||
throw AstError{ "function calls must happen inside a function" };
|
||||
|
||||
if (targetExpr->GetType() != NodeType::FunctionExpression)
|
||||
throw AstError{ "expected function expression" };
|
||||
std::size_t targetFuncIndex;
|
||||
if (targetExpr->GetType() == NodeType::FunctionExpression)
|
||||
targetFuncIndex = static_cast<FunctionExpression&>(*targetExpr).funcId;
|
||||
else if (targetExpr->GetType() == NodeType::AliasValueExpression)
|
||||
{
|
||||
const auto& alias = static_cast<AliasValueExpression&>(*targetExpr);
|
||||
|
||||
std::size_t targetFuncIndex = static_cast<FunctionExpression&>(*targetExpr).funcId;
|
||||
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId));
|
||||
if (targetIdentifier->category != IdentifierCategory::Function)
|
||||
throw AstError{ "expected function expression" };
|
||||
|
||||
targetFuncIndex = targetIdentifier->index;
|
||||
}
|
||||
else
|
||||
throw AstError{ "expected function expression" };
|
||||
|
||||
auto clone = std::make_unique<CallFunctionExpression>();
|
||||
clone->targetFunction = std::move(targetExpr);
|
||||
@@ -520,7 +532,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
return clone;
|
||||
}
|
||||
else if (IsIntrinsicFunctionType(targetExprType))
|
||||
else if (IsIntrinsicFunctionType(resolvedType))
|
||||
{
|
||||
if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression)
|
||||
throw AstError{ "expected intrinsic function expression" };
|
||||
@@ -538,9 +550,9 @@ namespace Nz::ShaderAst
|
||||
|
||||
return intrinsic;
|
||||
}
|
||||
else if (IsMethodType(targetExprType))
|
||||
else if (IsMethodType(resolvedType))
|
||||
{
|
||||
const MethodType& methodType = std::get<MethodType>(targetExprType);
|
||||
const MethodType& methodType = std::get<MethodType>(resolvedType);
|
||||
|
||||
std::vector<ExpressionPtr> parameters;
|
||||
parameters.reserve(node.parameters.size() + 1);
|
||||
@@ -987,6 +999,9 @@ namespace Nz::ShaderAst
|
||||
if (node.entryStage.HasValue())
|
||||
clone->entryStage = ComputeExprValue(node.entryStage);
|
||||
|
||||
if (node.isExported.HasValue())
|
||||
clone->isExported = ComputeExprValue(node.isExported);
|
||||
|
||||
if (clone->entryStage.HasValue())
|
||||
{
|
||||
ShaderStageType stageType = clone->entryStage.GetResultingValue();
|
||||
@@ -1468,7 +1483,7 @@ namespace Nz::ShaderAst
|
||||
hasher.Append(moduleUuidBytes.data(), moduleUuidBytes.size());
|
||||
hasher.End();
|
||||
|
||||
std::string identifier = "__" + hasher.End().ToHex().substr(0, 8);
|
||||
std::string identifier = "_" + hasher.End().ToHex().substr(0, 8);
|
||||
|
||||
// Load new module
|
||||
auto moduleEnvironment = std::make_shared<Environment>();
|
||||
@@ -1519,6 +1534,19 @@ namespace Nz::ShaderAst
|
||||
std::vector<DeclareAliasStatementPtr> aliasStatements;
|
||||
|
||||
AstExportVisitor::Callbacks callbacks;
|
||||
callbacks.onExportedFunc = [&](DeclareFunctionStatement& node)
|
||||
{
|
||||
assert(node.funcIndex);
|
||||
|
||||
moduleData.dependenciesVisitor->MarkFunctionAsUsed(*node.funcIndex);
|
||||
|
||||
if (!exportedSet.usedFunctions.UnboundedTest(*node.funcIndex))
|
||||
{
|
||||
exportedSet.usedFunctions.UnboundedSet(*node.funcIndex);
|
||||
aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.name, ShaderBuilder::Function(*node.funcIndex)));
|
||||
}
|
||||
};
|
||||
|
||||
callbacks.onExportedStruct = [&](DeclareStructStatement& node)
|
||||
{
|
||||
assert(node.structIndex);
|
||||
@@ -2475,13 +2503,18 @@ namespace Nz::ShaderAst
|
||||
std::size_t structIndex = ResolveStruct(exprType);
|
||||
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
|
||||
}
|
||||
else if (IsFunctionType(exprType))
|
||||
{
|
||||
std::size_t funcIndex = std::get<FunctionType>(exprType).funcIndex;
|
||||
node.aliasIndex = RegisterAlias(node.name, { funcIndex, IdentifierCategory::Function }, node.aliasIndex);
|
||||
}
|
||||
else if (IsAliasType(exprType))
|
||||
{
|
||||
const AliasType& alias = std::get<AliasType>(exprType);
|
||||
node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex);
|
||||
}
|
||||
else
|
||||
throw AstError{ "for now, only structs can be aliased" };
|
||||
throw AstError{ "for now, only aliases, functions and structs can be aliased" };
|
||||
}
|
||||
|
||||
void SanitizeVisitor::Validate(WhileStatement& node)
|
||||
@@ -2659,10 +2692,22 @@ namespace Nz::ShaderAst
|
||||
|
||||
void SanitizeVisitor::Validate(CallFunctionExpression& node)
|
||||
{
|
||||
const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
|
||||
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
||||
std::size_t targetFuncIndex;
|
||||
if (node.targetFunction->GetType() == NodeType::FunctionExpression)
|
||||
targetFuncIndex = static_cast<FunctionExpression&>(*node.targetFunction).funcId;
|
||||
else if (node.targetFunction->GetType() == NodeType::AliasValueExpression)
|
||||
{
|
||||
const auto& alias = static_cast<AliasValueExpression&>(*node.targetFunction);
|
||||
|
||||
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId));
|
||||
if (targetIdentifier->category != IdentifierCategory::Function)
|
||||
throw AstError{ "expected function expression" };
|
||||
|
||||
targetFuncIndex = targetIdentifier->index;
|
||||
}
|
||||
else
|
||||
throw AstError{ "expected function expression" };
|
||||
|
||||
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
|
||||
auto& funcData = m_context->functions.Retrieve(targetFuncIndex);
|
||||
|
||||
const DeclareFunctionStatement* referenceDeclaration = funcData.node;
|
||||
|
||||
@@ -85,7 +85,7 @@ namespace Nz
|
||||
assert(node.funcIndex);
|
||||
assert(functions.find(node.funcIndex.value()) == functions.end());
|
||||
FunctionData& funcData = functions[node.funcIndex.value()];
|
||||
funcData.name = node.name;
|
||||
funcData.name = node.name + moduleSuffix;
|
||||
funcData.node = &node;
|
||||
|
||||
currentFunction = &funcData;
|
||||
@@ -105,6 +105,7 @@ namespace Nz
|
||||
FunctionData* currentFunction = nullptr;
|
||||
|
||||
std::optional<ShaderStageType> selectedStage;
|
||||
std::string moduleSuffix;
|
||||
std::unordered_map<std::size_t, FunctionData> functions;
|
||||
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
|
||||
};
|
||||
@@ -201,8 +202,12 @@ namespace Nz
|
||||
state.previsitor.selectedStage = shaderStage;
|
||||
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
{
|
||||
state.previsitor.moduleSuffix = importedModule.identifier;
|
||||
importedModule.module->rootNode->Visit(state.previsitor);
|
||||
}
|
||||
|
||||
state.previsitor.moduleSuffix = {};
|
||||
targetModule->rootNode->Visit(state.previsitor);
|
||||
|
||||
if (!state.previsitor.entryPoint)
|
||||
@@ -216,7 +221,7 @@ namespace Nz
|
||||
|
||||
for (const auto& importedModule : targetModule->importedModules)
|
||||
{
|
||||
AppendComment("Module " + importedModule.module->metadata->moduleId.ToString());
|
||||
AppendComment("Module " + importedModule.module->metadata->moduleName);
|
||||
AppendLine();
|
||||
|
||||
m_currentState->moduleSuffix = importedModule.identifier;
|
||||
@@ -311,8 +316,7 @@ namespace Nz
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||
{
|
||||
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionType.funcIndex).name;
|
||||
Append(targetName);
|
||||
throw std::runtime_error("unexpected FunctionType");
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
|
||||
@@ -475,9 +479,9 @@ namespace Nz
|
||||
AppendLine();
|
||||
}
|
||||
|
||||
void GlslWriter::AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward)
|
||||
void GlslWriter::AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, const std::string& nameOverride, bool forward)
|
||||
{
|
||||
Append(node.returnType, " ", node.name, "(");
|
||||
Append(node.returnType, " ", nameOverride, "(");
|
||||
|
||||
bool first = true;
|
||||
for (const auto& parameter : node.parameters)
|
||||
@@ -970,8 +974,8 @@ namespace Nz
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::FunctionExpression& node)
|
||||
{
|
||||
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, node.funcId).name;
|
||||
Append(targetName);
|
||||
const auto& funcData = Retrieve(m_currentState->previsitor.functions, node.funcId);
|
||||
Append(funcData.name);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::IntrinsicExpression& node)
|
||||
@@ -1214,7 +1218,7 @@ namespace Nz
|
||||
hasPredeclaration = true;
|
||||
|
||||
auto& targetFunc = Retrieve(m_currentState->previsitor.functions, i);
|
||||
AppendFunctionDeclaration(*targetFunc.node, true);
|
||||
AppendFunctionDeclaration(*targetFunc.node, targetFunc.name, true);
|
||||
|
||||
m_currentState->declaredFunctions.UnboundedSet(i);
|
||||
}
|
||||
@@ -1232,7 +1236,7 @@ namespace Nz
|
||||
RegisterVariable(*parameter.varIndex, parameter.name);
|
||||
}
|
||||
|
||||
AppendFunctionDeclaration(node);
|
||||
AppendFunctionDeclaration(node, funcData.name);
|
||||
EnterScope();
|
||||
{
|
||||
AppendStatementList(node.statements);
|
||||
|
||||
@@ -194,8 +194,7 @@ namespace Nz
|
||||
|
||||
void LangWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||
{
|
||||
const std::string& targetName = Retrieve(m_currentState->functions, functionType.funcIndex).name;
|
||||
Append(targetName);
|
||||
throw std::runtime_error("unexpected function type");
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::IdentifierType& identifierType)
|
||||
@@ -971,7 +970,7 @@ namespace Nz
|
||||
|
||||
void LangWriter::Visit(ShaderAst::FunctionExpression& node)
|
||||
{
|
||||
Append(Retrieve(m_currentState->functions, node.funcId).name);
|
||||
AppendIdentifier(m_currentState->functions, node.funcId);
|
||||
}
|
||||
|
||||
void LangWriter::Visit(ShaderAst::IdentifierExpression& node)
|
||||
|
||||
@@ -686,24 +686,28 @@ namespace Nz::ShaderLang
|
||||
|
||||
ShaderAst::ExpressionValue<bool> condition;
|
||||
|
||||
for (auto&& [attributeType, arg] : attributes)
|
||||
for (auto&& [attributeType, attributeParam] : attributes)
|
||||
{
|
||||
switch (attributeType)
|
||||
{
|
||||
case ShaderAst::AttributeType::Cond:
|
||||
HandleUniqueAttribute("cond", condition, std::move(arg));
|
||||
HandleUniqueAttribute("cond", condition, std::move(attributeParam));
|
||||
break;
|
||||
|
||||
case ShaderAst::AttributeType::Entry:
|
||||
HandleUniqueStringAttribute("entry", s_entryPoints, func->entryStage, std::move(arg));
|
||||
HandleUniqueStringAttribute("entry", s_entryPoints, func->entryStage, std::move(attributeParam));
|
||||
break;
|
||||
|
||||
case ShaderAst::AttributeType::Export:
|
||||
HandleUniqueAttribute("export", func->isExported, std::move(attributeParam), true);
|
||||
break;
|
||||
|
||||
case ShaderAst::AttributeType::DepthWrite:
|
||||
HandleUniqueStringAttribute("depth_write", s_depthWriteModes, func->depthWrite, std::move(arg));
|
||||
HandleUniqueStringAttribute("depth_write", s_depthWriteModes, func->depthWrite, std::move(attributeParam));
|
||||
break;
|
||||
|
||||
case ShaderAst::AttributeType::EarlyFragmentTests:
|
||||
HandleUniqueAttribute("early_fragment_tests", func->earlyFragmentTests, std::move(arg), true);
|
||||
HandleUniqueAttribute("early_fragment_tests", func->earlyFragmentTests, std::move(attributeParam), true);
|
||||
break;
|
||||
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user