|
|
|
|
@@ -496,16 +496,28 @@ namespace Nz::ShaderAst
|
|
|
|
|
{
|
|
|
|
|
ExpressionPtr targetExpr = CloneExpression(MandatoryExpr(node.targetFunction));
|
|
|
|
|
const ExpressionType& targetExprType = GetExpressionType(*targetExpr);
|
|
|
|
|
const ExpressionType& resolvedType = ResolveAlias(targetExprType);
|
|
|
|
|
|
|
|
|
|
if (IsFunctionType(targetExprType))
|
|
|
|
|
if (IsFunctionType(resolvedType))
|
|
|
|
|
{
|
|
|
|
|
if (!m_context->currentFunction)
|
|
|
|
|
throw AstError{ "function calls must happen inside a function" };
|
|
|
|
|
|
|
|
|
|
if (targetExpr->GetType() != NodeType::FunctionExpression)
|
|
|
|
|
throw AstError{ "expected function expression" };
|
|
|
|
|
std::size_t targetFuncIndex;
|
|
|
|
|
if (targetExpr->GetType() == NodeType::FunctionExpression)
|
|
|
|
|
targetFuncIndex = static_cast<FunctionExpression&>(*targetExpr).funcId;
|
|
|
|
|
else if (targetExpr->GetType() == NodeType::AliasValueExpression)
|
|
|
|
|
{
|
|
|
|
|
const auto& alias = static_cast<AliasValueExpression&>(*targetExpr);
|
|
|
|
|
|
|
|
|
|
std::size_t targetFuncIndex = static_cast<FunctionExpression&>(*targetExpr).funcId;
|
|
|
|
|
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId));
|
|
|
|
|
if (targetIdentifier->category != IdentifierCategory::Function)
|
|
|
|
|
throw AstError{ "expected function expression" };
|
|
|
|
|
|
|
|
|
|
targetFuncIndex = targetIdentifier->index;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "expected function expression" };
|
|
|
|
|
|
|
|
|
|
auto clone = std::make_unique<CallFunctionExpression>();
|
|
|
|
|
clone->targetFunction = std::move(targetExpr);
|
|
|
|
|
@@ -520,7 +532,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
return clone;
|
|
|
|
|
}
|
|
|
|
|
else if (IsIntrinsicFunctionType(targetExprType))
|
|
|
|
|
else if (IsIntrinsicFunctionType(resolvedType))
|
|
|
|
|
{
|
|
|
|
|
if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression)
|
|
|
|
|
throw AstError{ "expected intrinsic function expression" };
|
|
|
|
|
@@ -538,9 +550,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
return intrinsic;
|
|
|
|
|
}
|
|
|
|
|
else if (IsMethodType(targetExprType))
|
|
|
|
|
else if (IsMethodType(resolvedType))
|
|
|
|
|
{
|
|
|
|
|
const MethodType& methodType = std::get<MethodType>(targetExprType);
|
|
|
|
|
const MethodType& methodType = std::get<MethodType>(resolvedType);
|
|
|
|
|
|
|
|
|
|
std::vector<ExpressionPtr> parameters;
|
|
|
|
|
parameters.reserve(node.parameters.size() + 1);
|
|
|
|
|
@@ -987,6 +999,9 @@ namespace Nz::ShaderAst
|
|
|
|
|
if (node.entryStage.HasValue())
|
|
|
|
|
clone->entryStage = ComputeExprValue(node.entryStage);
|
|
|
|
|
|
|
|
|
|
if (node.isExported.HasValue())
|
|
|
|
|
clone->isExported = ComputeExprValue(node.isExported);
|
|
|
|
|
|
|
|
|
|
if (clone->entryStage.HasValue())
|
|
|
|
|
{
|
|
|
|
|
ShaderStageType stageType = clone->entryStage.GetResultingValue();
|
|
|
|
|
@@ -1468,7 +1483,7 @@ namespace Nz::ShaderAst
|
|
|
|
|
hasher.Append(moduleUuidBytes.data(), moduleUuidBytes.size());
|
|
|
|
|
hasher.End();
|
|
|
|
|
|
|
|
|
|
std::string identifier = "__" + hasher.End().ToHex().substr(0, 8);
|
|
|
|
|
std::string identifier = "_" + hasher.End().ToHex().substr(0, 8);
|
|
|
|
|
|
|
|
|
|
// Load new module
|
|
|
|
|
auto moduleEnvironment = std::make_shared<Environment>();
|
|
|
|
|
@@ -1519,6 +1534,19 @@ namespace Nz::ShaderAst
|
|
|
|
|
std::vector<DeclareAliasStatementPtr> aliasStatements;
|
|
|
|
|
|
|
|
|
|
AstExportVisitor::Callbacks callbacks;
|
|
|
|
|
callbacks.onExportedFunc = [&](DeclareFunctionStatement& node)
|
|
|
|
|
{
|
|
|
|
|
assert(node.funcIndex);
|
|
|
|
|
|
|
|
|
|
moduleData.dependenciesVisitor->MarkFunctionAsUsed(*node.funcIndex);
|
|
|
|
|
|
|
|
|
|
if (!exportedSet.usedFunctions.UnboundedTest(*node.funcIndex))
|
|
|
|
|
{
|
|
|
|
|
exportedSet.usedFunctions.UnboundedSet(*node.funcIndex);
|
|
|
|
|
aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.name, ShaderBuilder::Function(*node.funcIndex)));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
callbacks.onExportedStruct = [&](DeclareStructStatement& node)
|
|
|
|
|
{
|
|
|
|
|
assert(node.structIndex);
|
|
|
|
|
@@ -2475,13 +2503,18 @@ namespace Nz::ShaderAst
|
|
|
|
|
std::size_t structIndex = ResolveStruct(exprType);
|
|
|
|
|
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
|
|
|
|
|
}
|
|
|
|
|
else if (IsFunctionType(exprType))
|
|
|
|
|
{
|
|
|
|
|
std::size_t funcIndex = std::get<FunctionType>(exprType).funcIndex;
|
|
|
|
|
node.aliasIndex = RegisterAlias(node.name, { funcIndex, IdentifierCategory::Function }, node.aliasIndex);
|
|
|
|
|
}
|
|
|
|
|
else if (IsAliasType(exprType))
|
|
|
|
|
{
|
|
|
|
|
const AliasType& alias = std::get<AliasType>(exprType);
|
|
|
|
|
node.aliasIndex = RegisterAlias(node.name, { alias.aliasIndex, IdentifierCategory::Alias }, node.aliasIndex);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "for now, only structs can be aliased" };
|
|
|
|
|
throw AstError{ "for now, only aliases, functions and structs can be aliased" };
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(WhileStatement& node)
|
|
|
|
|
@@ -2659,10 +2692,22 @@ namespace Nz::ShaderAst
|
|
|
|
|
|
|
|
|
|
void SanitizeVisitor::Validate(CallFunctionExpression& node)
|
|
|
|
|
{
|
|
|
|
|
const ExpressionType& targetFuncType = GetExpressionType(*node.targetFunction);
|
|
|
|
|
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
|
|
|
|
std::size_t targetFuncIndex;
|
|
|
|
|
if (node.targetFunction->GetType() == NodeType::FunctionExpression)
|
|
|
|
|
targetFuncIndex = static_cast<FunctionExpression&>(*node.targetFunction).funcId;
|
|
|
|
|
else if (node.targetFunction->GetType() == NodeType::AliasValueExpression)
|
|
|
|
|
{
|
|
|
|
|
const auto& alias = static_cast<AliasValueExpression&>(*node.targetFunction);
|
|
|
|
|
|
|
|
|
|
const IdentifierData* targetIdentifier = ResolveAliasIdentifier(&m_context->aliases.Retrieve(alias.aliasId));
|
|
|
|
|
if (targetIdentifier->category != IdentifierCategory::Function)
|
|
|
|
|
throw AstError{ "expected function expression" };
|
|
|
|
|
|
|
|
|
|
targetFuncIndex = targetIdentifier->index;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
throw AstError{ "expected function expression" };
|
|
|
|
|
|
|
|
|
|
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
|
|
|
|
|
auto& funcData = m_context->functions.Retrieve(targetFuncIndex);
|
|
|
|
|
|
|
|
|
|
const DeclareFunctionStatement* referenceDeclaration = funcData.node;
|
|
|
|
|
|