Modules are workings \o/

This commit is contained in:
Jérôme Leclercq
2022-03-08 20:26:02 +01:00
parent 83d26e209e
commit be9bdc4705
29 changed files with 742 additions and 256 deletions

View File

@@ -405,6 +405,16 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(FunctionExpression& node)
{
auto clone = std::make_unique<FunctionExpression>();
clone->funcId = node.funcId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(IdentifierExpression& node)
{
auto clone = std::make_unique<IdentifierExpression>();
@@ -429,6 +439,26 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(IntrinsicFunctionExpression& node)
{
auto clone = std::make_unique<IntrinsicFunctionExpression>();
clone->intrinsicId = node.intrinsicId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(StructTypeExpression& node)
{
auto clone = std::make_unique<StructTypeExpression>();
clone->structTypeId = node.structTypeId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(SwizzleExpression& node)
{
auto clone = std::make_unique<SwizzleExpression>();

View File

@@ -72,6 +72,11 @@ namespace Nz::ShaderAst
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(FunctionExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/)
{
/* Nothing to do */
@@ -83,6 +88,16 @@ namespace Nz::ShaderAst
param->Visit(*this);
}
void AstRecursiveVisitor::Visit(IntrinsicFunctionExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(StructTypeExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(SwizzleExpression& node)
{
if (node.expression)

View File

@@ -153,6 +153,21 @@ namespace Nz::ShaderAst
Node(param);
}
void AstSerializerBase::Serialize(IntrinsicFunctionExpression& node)
{
SizeT(node.intrinsicId);
}
void AstSerializerBase::Serialize(StructTypeExpression& node)
{
SizeT(node.structTypeId);
}
void AstSerializerBase::Serialize(FunctionExpression& node)
{
SizeT(node.funcId);
}
void AstSerializerBase::Serialize(SwizzleExpression& node)
{
SizeT(node.componentCount);

View File

@@ -72,6 +72,11 @@ namespace Nz::ShaderAst
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(FunctionExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
@@ -82,6 +87,16 @@ namespace Nz::ShaderAst
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(IntrinsicFunctionExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(StructTypeExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
{
if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1)

View File

@@ -137,20 +137,22 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr IndexRemapperVisitor::Clone(CallFunctionExpression& node)
ExpressionPtr IndexRemapperVisitor::Clone(FunctionExpression& node)
{
CallFunctionExpressionPtr clone = static_unique_pointer_cast<CallFunctionExpression>(AstCloner::Clone(node));
FunctionExpressionPtr clone = static_unique_pointer_cast<FunctionExpression>(AstCloner::Clone(node));
const auto& targetFuncType = GetExpressionType(*node.targetFunction);
if (std::holds_alternative<FunctionType>(targetFuncType))
{
const auto& funcType = std::get<FunctionType>(targetFuncType);
assert(clone->funcId);
clone->funcId = Retrieve(m_context->newFuncIndices, clone->funcId);
FunctionType newFunc;
newFunc.funcIndex = Retrieve(m_context->newFuncIndices, funcType.funcIndex);
clone->cachedExpressionType = ExpressionType{ newFunc }; //< FIXME We should add FunctionExpression like VariableExpression to handle this
}
return clone;
}
ExpressionPtr IndexRemapperVisitor::Clone(StructTypeExpression& node)
{
StructTypeExpressionPtr clone = static_unique_pointer_cast<StructTypeExpression>(AstCloner::Clone(node));
assert(clone->structTypeId);
clone->structTypeId = Retrieve(m_context->newStructIndices, clone->structTypeId);
return clone;
}

View File

@@ -7,6 +7,7 @@
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/StackArray.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Core/Hash/SHA256.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp>
#include <Nazara/Shader/Ast/AstExportVisitor.hpp>
@@ -114,6 +115,7 @@ namespace Nz::ShaderAst
struct SanitizeVisitor::Environment
{
Uuid moduleId;
std::shared_ptr<Environment> parentEnv;
std::vector<Identifier> identifiersInScope;
std::vector<Scope> scopes;
@@ -121,21 +123,31 @@ namespace Nz::ShaderAst
struct SanitizeVisitor::Context
{
struct ModuleData
{
std::unordered_map<Uuid, DependencyCheckerVisitor::UsageSet> exportedSetByModule;
std::shared_ptr<Environment> environment;
std::unique_ptr<DependencyCheckerVisitor> dependenciesVisitor;
};
struct PendingFunction
{
DeclareFunctionStatement* cloneNode;
const DeclareFunctionStatement* node;
};
static constexpr std::size_t ModuleIdSentinel = std::numeric_limits<std::size_t>::max();
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::vector<std::shared_ptr<Environment>> moduleEnvironments;
std::vector<ModuleData> modules;
std::vector<PendingFunction> pendingFunctions;
std::vector<StatementPtr>* currentStatementList = nullptr;
std::unordered_map<Uuid, std::size_t> moduleByUuid;
std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<UInt64> usedBindingIndexes;
std::shared_ptr<Environment> globalEnv;
std::shared_ptr<Environment> builtinEnv;
std::shared_ptr<Environment> currentEnv;
std::shared_ptr<Environment> globalEnv;
IdentifierList<ConstantValue> constantValues;
IdentifierList<FunctionData> functions;
IdentifierList<IdentifierData> aliases;
@@ -144,22 +156,57 @@ namespace Nz::ShaderAst
IdentifierList<StructDescription*> structs;
IdentifierList<std::variant<ExpressionType, PartialType>> types;
IdentifierList<ExpressionType> variableTypes;
ModulePtr currentModule;
Options options;
CurrentFunctionData* currentFunction = nullptr;
};
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
{
Context currentContext;
currentContext.options = options;
ModulePtr clone = std::make_shared<Module>(module.metadata);
clone->importedModules = module.importedModules;
Context currentContext;
currentContext.options = options;
currentContext.currentModule = clone;
m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; });
// Register builtin env
m_context->builtinEnv = std::make_shared<Environment>();
m_context->currentEnv = m_context->builtinEnv;
RegisterBuiltin();
m_context->globalEnv = std::make_shared<Environment>();
m_context->globalEnv->moduleId = clone->metadata->moduleId;
m_context->globalEnv->parentEnv = m_context->builtinEnv;
for (std::size_t moduleId = 0; moduleId < clone->importedModules.size(); ++moduleId)
{
auto moduleEnv = std::make_shared<Environment>();
moduleEnv->moduleId = clone->importedModules[moduleId].module->metadata->moduleId;
moduleEnv->parentEnv = m_context->builtinEnv;
m_context->currentEnv = moduleEnv;
// Previous modules are visibles
for (std::size_t previousModuleId = 0; previousModuleId < moduleId; ++previousModuleId)
RegisterModule(clone->importedModules[previousModuleId].identifier, previousModuleId);
auto& importedModule = clone->importedModules[moduleId];
importedModule.module->rootNode = SanitizeInternal(*importedModule.module->rootNode, error);
if (!importedModule.module->rootNode)
return {};
m_context->moduleByUuid[importedModule.module->metadata->moduleId] = moduleId;
auto& moduleData = m_context->modules.emplace_back();
moduleData.environment = std::move(moduleEnv);
m_context->currentEnv = m_context->globalEnv;
RegisterModule(importedModule.identifier, moduleId);
}
m_context->currentEnv = m_context->globalEnv;
clone->rootNode = SanitizeInternal(*module.rootNode, error);
@@ -211,7 +258,25 @@ namespace Nz::ShaderAst
if (node.identifiers.empty())
throw AstError{ "AccessIdentifierExpression must have at least one identifier" };
ExpressionPtr indexedExpr = CloneExpression(MandatoryExpr(node.expr));
MandatoryExpr(node.expr);
// Handle module access (TODO: Add namespace expression?)
if (node.expr->GetType() == NodeType::IdentifierExpression && node.identifiers.size() == 1)
{
auto& identifierExpr = static_cast<IdentifierExpression&>(*node.expr);
const IdentifierData* identifierData = FindIdentifier(identifierExpr.identifier);
if (identifierData && identifierData->category == IdentifierCategory::Module)
{
std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index);
const auto& env = *m_context->modules[moduleIndex].environment;
identifierData = FindIdentifier(env, node.identifiers.front());
if (identifierData)
return HandleIdentifier(identifierData);
}
}
ExpressionPtr indexedExpr = CloneExpression(node.expr);
for (const std::string& identifier : node.identifiers)
{
if (identifier.empty())
@@ -393,7 +458,10 @@ namespace Nz::ShaderAst
if (!m_context->currentFunction)
throw AstError{ "function calls must happen inside a function" };
std::size_t targetFuncIndex = std::get<FunctionType>(targetExprType).funcIndex;
if (targetExpr->GetType() != NodeType::FunctionExpression)
throw AstError{ "expected function expression" };
std::size_t targetFuncIndex = static_cast<FunctionExpression&>(*targetExpr).funcId;
auto clone = std::make_unique<CallFunctionExpression>();
clone->targetFunction = std::move(targetExpr);
@@ -410,13 +478,18 @@ namespace Nz::ShaderAst
}
else if (IsIntrinsicFunctionType(targetExprType))
{
if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression)
throw AstError{ "expected intrinsic function expression" };
std::size_t targetIntrinsicId = static_cast<IntrinsicFunctionExpression&>(*targetExpr).intrinsicId;
std::vector<ExpressionPtr> parameters;
parameters.reserve(node.parameters.size());
for (const auto& param : node.parameters)
parameters.push_back(CloneExpression(param));
auto intrinsic = ShaderBuilder::Intrinsic(std::get<IntrinsicFunctionType>(targetExprType).intrinsic, std::move(parameters));
auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics.Retrieve(targetIntrinsicId), std::move(parameters));
Validate(*intrinsic);
return intrinsic;
@@ -584,64 +657,7 @@ namespace Nz::ShaderAst
if (!identifierData)
throw AstError{ "unknown identifier " + node.identifier };
switch (identifierData->category)
{
case IdentifierCategory::Constant:
{
// Replace IdentifierExpression by Constant(Value)Expression
ConstantExpression constantExpr;
constantExpr.constantId = identifierData->index;
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
}
case IdentifierCategory::Function:
{
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = FunctionType{ identifierData->index };
return clone;
}
case IdentifierCategory::Intrinsic:
{
IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index);
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType };
return clone;
}
case IdentifierCategory::Struct:
{
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = StructType{ identifierData->index };
return clone;
}
case IdentifierCategory::Type:
{
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = Type{ identifierData->index };
return clone;
}
case IdentifierCategory::Variable:
{
// Replace IdentifierExpression by VariableExpression
auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index);
varExpr->variableId = identifierData->index;
return varExpr;
}
default:
throw AstError{ "unexpected identifier" };
}
return HandleIdentifier(identifierData);
}
ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node)
@@ -872,6 +888,8 @@ namespace Nz::ShaderAst
if (node.returnType.HasValue())
clone->returnType = ResolveType(node.returnType);
else
clone->returnType = ExpressionType{ NoType{} };
if (node.depthWrite.HasValue())
clone->depthWrite = ComputeExprValue(node.depthWrite);
@@ -1360,54 +1378,99 @@ namespace Nz::ShaderAst
if (!targetModule)
throw AstError{ "module " + ModulePathAsString() + " not found" };
targetModule->rootNode->sectionName = "Module " + targetModule->metadata->moduleId.ToString();
std::size_t moduleIndex;
m_context->currentEnv = m_context->moduleEnvironments.emplace_back(std::make_shared<Environment>());
CallOnExit restoreEnvOnExit([&] { m_context->currentEnv = m_context->globalEnv; });
const Uuid& moduleUuid = targetModule->metadata->moduleId;
auto it = m_context->moduleByUuid.find(moduleUuid);
if (it == m_context->moduleByUuid.end())
{
m_context->moduleByUuid[moduleUuid] = Context::ModuleIdSentinel;
ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata);
// Generate module identifier (based on UUID)
const auto& moduleUuidBytes = moduleUuid.ToArray();
std::string error;
sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error);
if (!sanitizedModule)
throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error };
SHA256Hash hasher;
hasher.Begin();
hasher.Append(moduleUuidBytes.data(), moduleUuidBytes.size());
hasher.End();
std::string identifier = "__" + hasher.End().ToHex().substr(0, 8);
// Load new module
auto moduleEnvironment = std::make_shared<Environment>();
moduleEnvironment->parentEnv = m_context->builtinEnv;
auto previousEnv = m_context->currentEnv;
m_context->currentEnv = moduleEnvironment;
ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata);
std::string error;
sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error);
if (!sanitizedModule)
throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error };
moduleIndex = m_context->modules.size();
assert(m_context->modules.size() == moduleIndex);
auto& moduleData = m_context->modules.emplace_back();
moduleData.dependenciesVisitor = std::make_unique<DependencyCheckerVisitor>();
moduleData.dependenciesVisitor->Process(*sanitizedModule->rootNode);
moduleData.environment = std::move(moduleEnvironment);
assert(m_context->currentModule->importedModules.size() == moduleIndex);
auto& importedModule = m_context->currentModule->importedModules.emplace_back();
importedModule.identifier = identifier;
importedModule.module = std::move(sanitizedModule);
m_context->currentEnv = std::move(previousEnv);
RegisterModule(identifier, moduleIndex);
m_context->moduleByUuid[moduleUuid] = moduleIndex;
}
else
{
// Module has already been imported
moduleIndex = it->second;
if (moduleIndex == Context::ModuleIdSentinel)
throw AstError{ "circular import detected" };
}
auto& moduleData = m_context->modules[moduleIndex];
auto& exportedSet = moduleData.exportedSetByModule[m_context->currentEnv->moduleId];
// Extract exported nodes and their dependencies
DependencyCheckerVisitor::Config depConfig;
depConfig.usedShaderStages.Clear();
DependencyCheckerVisitor moduleDependencies;
moduleDependencies.Process(*sanitizedModule->rootNode, depConfig);
DependencyCheckerVisitor::UsageSet exportedSet;
MultiStatementPtr aliasBlock = std::make_unique<MultiStatement>();
std::vector<DeclareAliasStatementPtr> aliasStatements;
AstExportVisitor::Callbacks callbacks;
callbacks.onExportedStruct = [&](DeclareStructStatement& node)
{
assert(node.structIndex);
moduleDependencies.MarkStructAsUsed(*node.structIndex);
exportedSet.usedStructs.UnboundedSet(*node.structIndex);
auto alias = Clone(node);
// TODO: DeclareAlias
aliasBlock->statements.emplace_back(std::move(alias));
moduleData.dependenciesVisitor->MarkStructAsUsed(*node.structIndex);
if (!exportedSet.usedStructs.UnboundedTest(*node.structIndex))
{
exportedSet.usedStructs.UnboundedSet(*node.structIndex);
aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.description.name, ShaderBuilder::StructType(*node.structIndex)));
}
};
AstExportVisitor exportVisitor;
exportVisitor.Visit(*sanitizedModule->rootNode, callbacks);
exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
moduleDependencies.Resolve();
if (aliasStatements.empty())
return ShaderBuilder::NoOp();
//m_context->
// Register exported variables (FIXME: This shouldn't be necessary and could be handled by the IndexRemapperVisitor)
//m_context->importUsage = remappedExportedSet;
//CallOnExit restoreImportOnExit([&] { m_context->importUsage.reset(); });
// Register module and aliases
MultiStatementPtr aliasBlock = std::make_unique<MultiStatement>();
for (auto& aliasPtr : aliasStatements)
{
Validate(*aliasPtr);
aliasBlock->statements.push_back(std::move(aliasPtr));
}
return aliasBlock;
}
@@ -1559,6 +1622,74 @@ namespace Nz::ShaderAst
throw std::runtime_error("internal error");
}
ExpressionPtr SanitizeVisitor::HandleIdentifier(const IdentifierData* identifierData)
{
switch (identifierData->category)
{
case IdentifierCategory::Constant:
{
// Replace IdentifierExpression by Constant(Value)Expression
ConstantExpression constantExpr;
constantExpr.constantId = identifierData->index;
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
}
case IdentifierCategory::Function:
{
// Replace IdentifierExpression by FunctionExpression
auto funcExpr = std::make_unique<FunctionExpression>();
funcExpr->cachedExpressionType = FunctionType{ identifierData->index }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type
funcExpr->funcId = identifierData->index;
return funcExpr;
}
case IdentifierCategory::Intrinsic:
{
IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index);
// Replace IdentifierExpression by IntrinsicFunctionExpression
auto intrinsicExpr = std::make_unique<IntrinsicFunctionExpression>();
intrinsicExpr->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type
intrinsicExpr->intrinsicId = identifierData->index;
return intrinsicExpr;
}
case IdentifierCategory::Struct:
{
// Replace IdentifierExpression by StructTypeExpression
auto structExpr = std::make_unique<StructTypeExpression>();
structExpr->cachedExpressionType = StructType{ identifierData->index };
structExpr->structTypeId = identifierData->index;
return structExpr;
}
case IdentifierCategory::Type:
{
auto clone = ShaderBuilder::Identifier("dummy");
clone->cachedExpressionType = Type{ identifierData->index };
return clone;
}
case IdentifierCategory::Variable:
{
// Replace IdentifierExpression by VariableExpression
auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index);
varExpr->variableId = identifierData->index;
return varExpr;
}
default:
throw AstError{ "unexpected identifier" };
}
}
Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node) const
{
if (!node)
@@ -1909,20 +2040,20 @@ namespace Nz::ShaderAst
return intrinsicIndex;
}
std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex)
std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t index)
{
if (FindIdentifier(moduleIdentifier))
throw AstError{ moduleIdentifier + " is already used" };
std::size_t intrinsicIndex = m_context->moduleIndices.Register(moduleIndex);
std::size_t moduleIndex = m_context->moduleIndices.Register(index);
m_context->currentEnv->identifiersInScope.push_back({
std::move(moduleIdentifier),
intrinsicIndex,
moduleIndex,
IdentifierCategory::Module
});
return intrinsicIndex;
return moduleIndex;
}
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index)
@@ -2169,11 +2300,7 @@ namespace Nz::ShaderAst
MultiStatementPtr SanitizeVisitor::SanitizeInternal(MultiStatement& rootNode, std::string* error)
{
MultiStatementPtr output;
PushScope(); //< Global scope
{
RegisterBuiltin();
// First pass, evaluate everything except function code
try
{
@@ -2196,7 +2323,6 @@ namespace Nz::ShaderAst
ResolveFunctions();
}
PopScope();
return output;
}
@@ -2261,7 +2387,7 @@ namespace Nz::ShaderAst
StackVector<TypeParameter> parameters = NazaraStackVector(TypeParameter, partialType.parameters.size());
for (std::size_t i = 0; i < partialType.parameters.size(); ++i)
{
ExpressionPtr indexExpr = CloneExpression(node.indices[i]);
const ExpressionPtr& indexExpr = node.indices[i];
switch (partialType.parameters[i])
{
case TypeParameterCategory::ConstantValue: