This commit is contained in:
Jérôme Leclercq 2022-03-07 19:20:14 +01:00
parent d72ac9cc73
commit 012712b8d0
13 changed files with 398 additions and 327 deletions

View File

@ -17,6 +17,7 @@ namespace Nz::ShaderAst
{ {
inline bool Compare(const Expression& lhs, const Expression& rhs); inline bool Compare(const Expression& lhs, const Expression& rhs);
inline bool Compare(const Module& lhs, const Module& rhs); inline bool Compare(const Module& lhs, const Module& rhs);
inline bool Compare(const Module::ImportedModule& lhs, const Module::ImportedModule& rhs);
inline bool Compare(const Module::Metadata& lhs, const Module::Metadata& rhs); inline bool Compare(const Module::Metadata& lhs, const Module::Metadata& rhs);
inline bool Compare(const Statement& lhs, const Statement& rhs); inline bool Compare(const Statement& lhs, const Statement& rhs);

View File

@ -31,12 +31,29 @@ namespace Nz::ShaderAst
if (!Compare(*lhs.metadata, *rhs.metadata)) if (!Compare(*lhs.metadata, *rhs.metadata))
return false; return false;
if (!Compare(lhs.importedModules, rhs.importedModules))
return false;
if (!Compare(*lhs.rootNode, *rhs.rootNode)) if (!Compare(*lhs.rootNode, *rhs.rootNode))
return false; return false;
return true; return true;
} }
bool Compare(const Module::ImportedModule& lhs, const Module::ImportedModule& rhs)
{
if (!Compare(lhs.identifier, rhs.identifier))
return false;
if (!Compare(lhs.dependencies, rhs.dependencies))
return false;
if (!Compare(*lhs.module, *rhs.module))
return false;
return false;
}
bool Compare(const Module::Metadata& lhs, const Module::Metadata& rhs) bool Compare(const Module::Metadata& lhs, const Module::Metadata& rhs)
{ {
if (!Compare(lhs.moduleId, rhs.moduleId)) if (!Compare(lhs.moduleId, rhs.moduleId))

View File

@ -39,6 +39,7 @@ namespace Nz
LangVersion, //< NZSL version - has argument version string LangVersion, //< NZSL version - has argument version string
Set, //< Binding set (external var only) - has argument index Set, //< Binding set (external var only) - has argument index
Unroll, //< Unroll (for/for each only) - has argument mode Unroll, //< Unroll (for/for each only) - has argument mode
Uuid, //< Uuid (module only) - has argument string
}; };
enum class BinaryType enum class BinaryType

View File

@ -24,7 +24,7 @@ namespace Nz::ShaderAst
public: public:
struct Metadata; struct Metadata;
inline Module(UInt32 shaderLangVersion); inline Module(UInt32 shaderLangVersion, const Uuid& moduleId = Uuid::Generate());
inline Module(std::shared_ptr<const Metadata> metadata); inline Module(std::shared_ptr<const Metadata> metadata);
inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode); inline Module(std::shared_ptr<const Metadata> metadata, MultiStatementPtr rootNode);
Module(const Module&) = default; Module(const Module&) = default;
@ -34,14 +34,22 @@ namespace Nz::ShaderAst
Module& operator=(const Module&) = default; Module& operator=(const Module&) = default;
Module& operator=(Module&&) noexcept = default; Module& operator=(Module&&) noexcept = default;
std::shared_ptr<const Metadata> metadata; struct ImportedModule
MultiStatementPtr rootNode; {
std::string identifier;
std::vector<Uuid> dependencies;
ModulePtr module;
};
struct Metadata struct Metadata
{ {
UInt32 shaderLangVersion; UInt32 shaderLangVersion;
Uuid moduleId; Uuid moduleId;
}; };
std::shared_ptr<const Metadata> metadata;
std::vector<ImportedModule> importedModules;
MultiStatementPtr rootNode;
}; };
} }

View File

@ -8,10 +8,10 @@
namespace Nz::ShaderAst namespace Nz::ShaderAst
{ {
inline Module::Module(UInt32 shaderLangVersion) inline Module::Module(UInt32 shaderLangVersion, const Uuid& uuid)
{ {
auto mutMetadata = std::make_shared<Metadata>(); auto mutMetadata = std::make_shared<Metadata>();
mutMetadata->moduleId = Uuid::Generate(); mutMetadata->moduleId = uuid;
mutMetadata->shaderLangVersion = shaderLangVersion; mutMetadata->shaderLangVersion = shaderLangVersion;
metadata = std::move(mutMetadata); metadata = std::move(mutMetadata);

View File

@ -250,7 +250,6 @@ namespace Nz::ShaderAst
void Visit(AstStatementVisitor& visitor) override; void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> constIndex; std::optional<std::size_t> constIndex;
std::optional<bool> hidden;
std::string name; std::string name;
ExpressionPtr expression; ExpressionPtr expression;
ExpressionValue<ExpressionType> type; ExpressionValue<ExpressionType> type;
@ -270,7 +269,6 @@ namespace Nz::ShaderAst
ExpressionValue<ExpressionType> type; ExpressionValue<ExpressionType> type;
}; };
std::optional<bool> hidden;
std::vector<ExternalVar> externalVars; std::vector<ExternalVar> externalVars;
ExpressionValue<UInt32> bindingSet; ExpressionValue<UInt32> bindingSet;
}; };
@ -288,7 +286,6 @@ namespace Nz::ShaderAst
}; };
std::optional<std::size_t> funcIndex; std::optional<std::size_t> funcIndex;
std::optional<bool> hidden;
std::string name; std::string name;
std::vector<Parameter> parameters; std::vector<Parameter> parameters;
std::vector<StatementPtr> statements; std::vector<StatementPtr> statements;
@ -304,7 +301,6 @@ namespace Nz::ShaderAst
void Visit(AstStatementVisitor& visitor) override; void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> optIndex; std::optional<std::size_t> optIndex;
std::optional<bool> hidden;
std::string optName; std::string optName;
ExpressionPtr defaultValue; ExpressionPtr defaultValue;
ExpressionValue<ExpressionType> optType; ExpressionValue<ExpressionType> optType;
@ -316,7 +312,6 @@ namespace Nz::ShaderAst
void Visit(AstStatementVisitor& visitor) override; void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> structIndex; std::optional<std::size_t> structIndex;
std::optional<bool> hidden;
ExpressionValue<bool> isExported; ExpressionValue<bool> isExported;
StructDescription description; StructDescription description;
}; };

View File

@ -57,8 +57,12 @@ namespace Nz::ShaderAst
}; };
private: private:
struct CurrentFunctionData;
struct Environment;
struct FunctionData; struct FunctionData;
struct Identifier; struct Identifier;
template<typename T> struct IdentifierData;
struct Scope;
using AstCloner::CloneExpression; using AstCloner::CloneExpression;
ExpressionValue<ExpressionType> CloneType(const ExpressionValue<ExpressionType>& exprType) override; ExpressionValue<ExpressionType> CloneType(const ExpressionValue<ExpressionType>& exprType) override;
@ -97,6 +101,8 @@ namespace Nz::ShaderAst
const Identifier* FindIdentifier(const std::string_view& identifierName) const; const Identifier* FindIdentifier(const std::string_view& identifierName) const;
template<typename F> const Identifier* FindIdentifier(const std::string_view& identifierName, F&& functor) const; template<typename F> const Identifier* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
const Identifier* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
template<typename F> const Identifier* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
TypeParameter FindTypeParameter(const std::string_view& identifierName) const; TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
Expression& MandatoryExpr(const ExpressionPtr& node) const; Expression& MandatoryExpr(const ExpressionPtr& node) const;
@ -114,13 +120,14 @@ namespace Nz::ShaderAst
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen); void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
void RegisterBuiltin(); void RegisterBuiltin();
std::size_t RegisterConstant(std::string name, ConstantValue value, bool hidden = false, std::optional<std::size_t> index = {}); std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {});
std::size_t RegisterFunction(std::string name, FunctionData funcData, bool hidden = false, std::optional<std::size_t> index = {}); std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {});
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type, bool hidden = false, std::optional<std::size_t> index = {}); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
std::size_t RegisterStruct(std::string name, StructDescription* description, bool hidden = false, std::optional<std::size_t> index = {}); std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
std::size_t RegisterType(std::string name, ExpressionType expressionType, bool hidden = false, std::optional<std::size_t> index = {}); std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, PartialType partialType, bool hidden = false, std::optional<std::size_t> index = {}); std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, bool hidden = false, std::optional<std::size_t> index = {}); std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
void ResolveFunctions(); void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node); const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
@ -132,6 +139,7 @@ namespace Nz::ShaderAst
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue); ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue);
void SanitizeIdentifier(std::string& identifier); void SanitizeIdentifier(std::string& identifier);
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const; void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const; void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
@ -167,6 +175,7 @@ namespace Nz::ShaderAst
Constant, Constant,
Function, Function,
Intrinsic, Intrinsic,
Module,
Struct, Struct,
Type, Type,
Variable Variable

View File

@ -81,7 +81,6 @@ namespace Nz::ShaderAst
{ {
auto clone = std::make_unique<DeclareConstStatement>(); auto clone = std::make_unique<DeclareConstStatement>();
clone->constIndex = node.constIndex; clone->constIndex = node.constIndex;
clone->hidden = node.hidden;
clone->name = node.name; clone->name = node.name;
clone->type = Clone(node.type); clone->type = Clone(node.type);
clone->expression = CloneExpression(node.expression); clone->expression = CloneExpression(node.expression);
@ -93,7 +92,6 @@ namespace Nz::ShaderAst
{ {
auto clone = std::make_unique<DeclareExternalStatement>(); auto clone = std::make_unique<DeclareExternalStatement>();
clone->bindingSet = Clone(node.bindingSet); clone->bindingSet = Clone(node.bindingSet);
clone->hidden = node.hidden;
clone->externalVars.reserve(node.externalVars.size()); clone->externalVars.reserve(node.externalVars.size());
for (const auto& var : node.externalVars) for (const auto& var : node.externalVars)
@ -116,7 +114,6 @@ namespace Nz::ShaderAst
clone->earlyFragmentTests = Clone(node.earlyFragmentTests); clone->earlyFragmentTests = Clone(node.earlyFragmentTests);
clone->entryStage = Clone(node.entryStage); clone->entryStage = Clone(node.entryStage);
clone->funcIndex = node.funcIndex; clone->funcIndex = node.funcIndex;
clone->hidden = node.hidden;
clone->name = node.name; clone->name = node.name;
clone->returnType = Clone(node.returnType); clone->returnType = Clone(node.returnType);
@ -140,7 +137,6 @@ namespace Nz::ShaderAst
{ {
auto clone = std::make_unique<DeclareOptionStatement>(); auto clone = std::make_unique<DeclareOptionStatement>();
clone->defaultValue = CloneExpression(node.defaultValue); clone->defaultValue = CloneExpression(node.defaultValue);
clone->hidden = node.hidden;
clone->optIndex = node.optIndex; clone->optIndex = node.optIndex;
clone->optName = node.optName; clone->optName = node.optName;
clone->optType = Clone(node.optType); clone->optType = Clone(node.optType);
@ -151,7 +147,6 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(DeclareStructStatement& node) StatementPtr AstCloner::Clone(DeclareStructStatement& node)
{ {
auto clone = std::make_unique<DeclareStructStatement>(); auto clone = std::make_unique<DeclareStructStatement>();
clone->hidden = node.hidden;
clone->isExported = Clone(node.isExported); clone->isExported = Clone(node.isExported);
clone->structIndex = node.structIndex; clone->structIndex = node.structIndex;

View File

@ -1,10 +0,0 @@
// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Engine - Shader module"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/AstCompare.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
}

View File

@ -195,7 +195,6 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareExternalStatement& node) void AstSerializerBase::Serialize(DeclareExternalStatement& node)
{ {
OptVal(node.hidden);
ExprValue(node.bindingSet); ExprValue(node.bindingSet);
Container(node.externalVars); Container(node.externalVars);
@ -212,7 +211,6 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareConstStatement& node) void AstSerializerBase::Serialize(DeclareConstStatement& node)
{ {
OptVal(node.constIndex); OptVal(node.constIndex);
OptVal(node.hidden);
Value(node.name); Value(node.name);
ExprValue(node.type); ExprValue(node.type);
Node(node.expression); Node(node.expression);
@ -226,7 +224,6 @@ namespace Nz::ShaderAst
ExprValue(node.earlyFragmentTests); ExprValue(node.earlyFragmentTests);
ExprValue(node.entryStage); ExprValue(node.entryStage);
OptVal(node.funcIndex); OptVal(node.funcIndex);
OptVal(node.hidden);
Container(node.parameters); Container(node.parameters);
for (auto& parameter : node.parameters) for (auto& parameter : node.parameters)
@ -244,7 +241,6 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareOptionStatement& node) void AstSerializerBase::Serialize(DeclareOptionStatement& node)
{ {
OptVal(node.optIndex); OptVal(node.optIndex);
OptVal(node.hidden);
Value(node.optName); Value(node.optName);
ExprValue(node.optType); ExprValue(node.optType);
Node(node.defaultValue); Node(node.defaultValue);
@ -253,7 +249,6 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(DeclareStructStatement& node) void AstSerializerBase::Serialize(DeclareStructStatement& node)
{ {
OptVal(node.structIndex); OptVal(node.structIndex);
OptVal(node.hidden);
ExprValue(node.isExported); ExprValue(node.isExported);
Value(node.description.name); Value(node.description.name);

View File

@ -37,9 +37,7 @@ namespace Nz::ShaderAst
} }
} }
struct SanitizeVisitor::Context struct SanitizeVisitor::CurrentFunctionData
{
struct CurrentFunctionData
{ {
std::optional<ShaderStageType> stageType; std::optional<ShaderStageType> stageType;
Bitset<> calledFunctions; Bitset<> calledFunctions;
@ -48,7 +46,7 @@ namespace Nz::ShaderAst
}; };
template<typename T> template<typename T>
struct IdentifierData struct SanitizeVisitor::IdentifierData
{ {
Bitset<UInt64> availableIndices; Bitset<UInt64> availableIndices;
Bitset<UInt64> preregisteredIndices; Bitset<UInt64> preregisteredIndices;
@ -109,72 +107,63 @@ namespace Nz::ShaderAst
} }
}; };
struct SanitizeVisitor::Scope
{
std::size_t previousSize;
};
struct SanitizeVisitor::Environment
{
std::shared_ptr<Environment> parentEnv;
std::vector<Identifier> identifiersInScope;
std::vector<Scope> scopes;
IdentifierData<ConstantValue> constantValues;
IdentifierData<FunctionData> functions;
IdentifierData<IntrinsicType> intrinsics;
IdentifierData<std::size_t> moduleIndices;
IdentifierData<StructDescription*> structs;
IdentifierData<std::variant<ExpressionType, PartialType>> types;
IdentifierData<ExpressionType> variableTypes;
};
struct SanitizeVisitor::Context
{
struct PendingFunction struct PendingFunction
{ {
DeclareFunctionStatement* cloneNode; DeclareFunctionStatement* cloneNode;
const DeclareFunctionStatement* node; const DeclareFunctionStatement* node;
}; };
struct Scope
{
std::size_t previousSize;
};
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {}; std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::optional<DependencyCheckerVisitor::UsageSet> importUsage; std::vector<std::shared_ptr<Environment>> moduleEnvironments;
std::vector<Identifier> identifiersInScope;
std::vector<PendingFunction> pendingFunctions; std::vector<PendingFunction> pendingFunctions;
std::vector<Scope> scopes;
std::vector<StatementPtr>* currentStatementList = nullptr; std::vector<StatementPtr>* currentStatementList = nullptr;
std::unordered_map<Uuid, std::size_t> moduleByUuid;
std::unordered_set<std::string> declaredExternalVar; std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<UInt64> usedBindingIndexes; std::unordered_set<UInt64> usedBindingIndexes;
IdentifierData<ConstantValue> constantValues; std::shared_ptr<Environment> globalEnv;
IdentifierData<FunctionData> functions; std::shared_ptr<Environment> currentEnv;
IdentifierData<IntrinsicType> intrinsics;
IdentifierData<StructDescription*> structs;
IdentifierData<std::variant<ExpressionType, PartialType>> types;
IdentifierData<ExpressionType> variableTypes;
Options options; Options options;
CurrentFunctionData* currentFunction = nullptr; CurrentFunctionData* currentFunction = nullptr;
}; };
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error) ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
{ {
ModulePtr clone = std::make_shared<Module>(module.metadata);
Context currentContext; Context currentContext;
currentContext.options = options; currentContext.options = options;
ModulePtr clone = std::make_shared<Module>(module.metadata);
clone->importedModules = module.importedModules;
m_context = &currentContext; m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; }); CallOnExit resetContext([&] { m_context = nullptr; });
PushScope(); //< Global scope m_context->globalEnv = std::make_shared<Environment>();
{ m_context->currentEnv = m_context->globalEnv;
RegisterBuiltin();
// First pass, evaluate everything except function code clone->rootNode = SanitizeInternal(*module.rootNode, error);
try if (!clone->rootNode)
{ return {};
clone->rootNode = static_unique_pointer_cast<MultiStatement>(AstCloner::Clone(*module.rootNode));
}
catch (const AstError& err)
{
if (!error)
throw std::runtime_error(err.errMsg);
*error = err.errMsg;
}
catch (const std::runtime_error& err)
{
if (!error)
throw;
*error = err.what();
}
ResolveFunctions();
}
PopScope();
return clone; return clone;
} }
@ -252,7 +241,7 @@ namespace Nz::ShaderAst
else if (IsStructType(exprType)) else if (IsStructType(exprType))
{ {
std::size_t structIndex = ResolveStruct(exprType); std::size_t structIndex = ResolveStruct(exprType);
const StructDescription* s = m_context->structs.Retrieve(structIndex); const StructDescription* s = m_context->currentEnv->structs.Retrieve(structIndex);
// Retrieve member index (not counting disabled fields) // Retrieve member index (not counting disabled fields)
Int32 fieldIndex = 0; Int32 fieldIndex = 0;
@ -580,7 +569,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
{ {
// Replace by constant value // Replace by constant value
auto constant = ShaderBuilder::Constant(m_context->constantValues.Retrieve(node.constantId)); auto constant = ShaderBuilder::Constant(m_context->currentEnv->constantValues.Retrieve(node.constantId));
constant->cachedExpressionType = GetExpressionType(constant->value); constant->cachedExpressionType = GetExpressionType(constant->value);
return constant; return constant;
@ -615,7 +604,7 @@ namespace Nz::ShaderAst
case Identifier::Type::Intrinsic: case Identifier::Type::Intrinsic:
{ {
IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifier->index); IntrinsicType intrinsicType = m_context->currentEnv->intrinsics.Retrieve(identifier->index);
auto clone = AstCloner::Clone(node); auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType };
@ -643,7 +632,7 @@ namespace Nz::ShaderAst
{ {
// Replace IdentifierExpression by VariableExpression // Replace IdentifierExpression by VariableExpression
auto varExpr = std::make_unique<VariableExpression>(); auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifier->index); varExpr->cachedExpressionType = m_context->currentEnv->variableTypes.Retrieve(identifier->index);
varExpr->variableId = identifier->index; varExpr->variableId = identifier->index;
return varExpr; return varExpr;
@ -794,10 +783,7 @@ namespace Nz::ShaderAst
clone->type = expressionType; clone->type = expressionType;
if (m_context->importUsage.has_value()) clone->constIndex = RegisterConstant(clone->name, value, clone->constIndex);
clone->hidden = true;
clone->constIndex = RegisterConstant(clone->name, value, clone->hidden.value_or(false), clone->constIndex);
if (m_context->options.removeConstDeclaration) if (m_context->options.removeConstDeclaration)
return ShaderBuilder::NoOp(); return ShaderBuilder::NoOp();
@ -811,13 +797,6 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<DeclareExternalStatement>(AstCloner::Clone(node)); auto clone = static_unique_pointer_cast<DeclareExternalStatement>(AstCloner::Clone(node));
if (m_context->importUsage.has_value())
{
// Since unused variables have been removed when importing a module, every variable should be used
assert(!clone->externalVars.empty());
clone->hidden = !m_context->importUsage->usedVariables.UnboundedTest(*clone->externalVars.front().varIndex);
}
UInt32 defaultBlockSet = 0; UInt32 defaultBlockSet = 0;
if (clone->bindingSet.HasValue()) if (clone->bindingSet.HasValue())
defaultBlockSet = ComputeExprValue(clone->bindingSet); defaultBlockSet = ComputeExprValue(clone->bindingSet);
@ -858,7 +837,7 @@ namespace Nz::ShaderAst
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
extVar.type = std::move(resolvedType); extVar.type = std::move(resolvedType);
extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), clone->hidden.value_or(false), extVar.varIndex); extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), extVar.varIndex);
SanitizeIdentifier(extVar.name); SanitizeIdentifier(extVar.name);
} }
@ -916,12 +895,6 @@ namespace Nz::ShaderAst
} }
} }
if (m_context->importUsage.has_value())
{
assert(clone->funcIndex);
clone->hidden = !m_context->importUsage->usedStructs.UnboundedTest(*clone->funcIndex);
}
// Function content is resolved in a second pass // Function content is resolved in a second pass
auto& pendingFunc = m_context->pendingFunctions.emplace_back(); auto& pendingFunc = m_context->pendingFunctions.emplace_back();
pendingFunc.cloneNode = clone.get(); pendingFunc.cloneNode = clone.get();
@ -936,7 +909,7 @@ namespace Nz::ShaderAst
FunctionData funcData; FunctionData funcData;
funcData.node = clone.get(); //< update function node funcData.node = clone.get(); //< update function node
std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData), clone->hidden.value_or(false), node.funcIndex); std::size_t funcIndex = RegisterFunction(clone->name, std::move(funcData), node.funcIndex);
clone->funcIndex = funcIndex; clone->funcIndex = funcIndex;
SanitizeIdentifier(clone->name); SanitizeIdentifier(clone->name);
@ -962,13 +935,10 @@ namespace Nz::ShaderAst
UInt32 optionHash = CRC32(reinterpret_cast<const UInt8*>(clone->optName.data()), clone->optName.size()); UInt32 optionHash = CRC32(reinterpret_cast<const UInt8*>(clone->optName.data()), clone->optName.size());
if (m_context->importUsage.has_value())
clone->hidden = true;
if (auto optionValueIt = m_context->options.optionValues.find(optionHash); optionValueIt != m_context->options.optionValues.end()) if (auto optionValueIt = m_context->options.optionValues.find(optionHash); optionValueIt != m_context->options.optionValues.end())
clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second, clone->hidden.value_or(false), clone->optIndex); clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second, clone->optIndex);
else if (clone->defaultValue) else if (clone->defaultValue)
clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue), clone->hidden.value_or(false), clone->optIndex); clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue), clone->optIndex);
else else
throw AstError{ "missing option " + clone->optName + " value (has no default value)" }; throw AstError{ "missing option " + clone->optName + " value (has no default value)" };
@ -988,12 +958,6 @@ namespace Nz::ShaderAst
if (clone->isExported.HasValue()) if (clone->isExported.HasValue())
clone->isExported = ComputeExprValue(clone->isExported); clone->isExported = ComputeExprValue(clone->isExported);
if (m_context->importUsage.has_value())
{
assert(clone->structIndex);
clone->hidden = !m_context->importUsage->usedStructs.UnboundedTest(*clone->structIndex);
}
std::unordered_set<std::string> declaredMembers; std::unordered_set<std::string> declaredMembers;
for (auto& member : clone->description.members) for (auto& member : clone->description.members)
{ {
@ -1023,7 +987,7 @@ namespace Nz::ShaderAst
else if (IsStructType(resolvedType)) else if (IsStructType(resolvedType))
{ {
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex; std::size_t structIndex = std::get<StructType>(resolvedType).structIndex;
const StructDescription* desc = m_context->structs.Retrieve(structIndex); const StructDescription* desc = m_context->currentEnv->structs.Retrieve(structIndex);
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue()) if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
throw AstError{ "inner struct layout mismatch" }; throw AstError{ "inner struct layout mismatch" };
} }
@ -1032,7 +996,7 @@ namespace Nz::ShaderAst
member.type = std::move(resolvedType); member.type = std::move(resolvedType);
} }
clone->structIndex = RegisterStruct(clone->description.name, &clone->description, clone->hidden.value_or(false), clone->structIndex); clone->structIndex = RegisterStruct(clone->description.name, &clone->description, clone->structIndex);
SanitizeIdentifier(clone->description.name); SanitizeIdentifier(clone->description.name);
@ -1362,9 +1326,6 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(ImportStatement& node) StatementPtr SanitizeVisitor::Clone(ImportStatement& node)
{ {
// Nested import is handled separately
assert(!m_context->importUsage);
if (!m_context->options.moduleCallback) if (!m_context->options.moduleCallback)
return static_unique_pointer_cast<ImportStatement>(AstCloner::Clone(node)); return static_unique_pointer_cast<ImportStatement>(AstCloner::Clone(node));
@ -1392,8 +1353,14 @@ namespace Nz::ShaderAst
targetModule->rootNode->sectionName = "Module " + targetModule->metadata->moduleId.ToString(); targetModule->rootNode->sectionName = "Module " + targetModule->metadata->moduleId.ToString();
m_context->currentEnv = m_context->moduleEnvironments.emplace_back();
m_context->currentEnv->parentEnv = m_context->globalEnv;
CallOnExit restoreEnvOnExit([&] { m_context->currentEnv = m_context->globalEnv; });
ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata);
std::string error; std::string error;
ModulePtr sanitizedModule = ShaderAst::Sanitize(*targetModule, m_context->options, &error); sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error);
if (!sanitizedModule) if (!sanitizedModule)
throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error }; throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error };
@ -1425,10 +1392,10 @@ namespace Nz::ShaderAst
DependencyCheckerVisitor::UsageSet remappedExportedSet; DependencyCheckerVisitor::UsageSet remappedExportedSet;
IndexRemapperVisitor::Callbacks remapCallbacks; IndexRemapperVisitor::Callbacks remapCallbacks;
remapCallbacks.constIndexGenerator = [this](std::size_t previousIndex) { return m_context->constantValues.RegisterNewIndex(true); }; remapCallbacks.constIndexGenerator = [this](std::size_t previousIndex) { return m_context->currentEnv->constantValues.RegisterNewIndex(true); };
remapCallbacks.funcIndexGenerator = [&](std::size_t previousIndex) remapCallbacks.funcIndexGenerator = [&](std::size_t previousIndex)
{ {
std::size_t newIndex = m_context->functions.RegisterNewIndex(true); std::size_t newIndex = m_context->currentEnv->functions.RegisterNewIndex(true);
if (exportedSet.usedFunctions.Test(previousIndex)) if (exportedSet.usedFunctions.Test(previousIndex))
remappedExportedSet.usedFunctions.UnboundedSet(newIndex); remappedExportedSet.usedFunctions.UnboundedSet(newIndex);
@ -1437,7 +1404,7 @@ namespace Nz::ShaderAst
remapCallbacks.structIndexGenerator = [&](std::size_t previousIndex) remapCallbacks.structIndexGenerator = [&](std::size_t previousIndex)
{ {
std::size_t newIndex = m_context->structs.RegisterNewIndex(true); std::size_t newIndex = m_context->currentEnv->structs.RegisterNewIndex(true);
if (exportedSet.usedStructs.Test(previousIndex)) if (exportedSet.usedStructs.Test(previousIndex))
remappedExportedSet.usedStructs.UnboundedSet(newIndex); remappedExportedSet.usedStructs.UnboundedSet(newIndex);
@ -1446,7 +1413,7 @@ namespace Nz::ShaderAst
remapCallbacks.varIndexGenerator = [&](std::size_t previousIndex) remapCallbacks.varIndexGenerator = [&](std::size_t previousIndex)
{ {
std::size_t newIndex = m_context->variableTypes.RegisterNewIndex(true); std::size_t newIndex = m_context->currentEnv->variableTypes.RegisterNewIndex(true);
if (exportedSet.usedVariables.Test(previousIndex)) if (exportedSet.usedVariables.Test(previousIndex))
remappedExportedSet.usedVariables.UnboundedSet(newIndex); remappedExportedSet.usedVariables.UnboundedSet(newIndex);
@ -1456,8 +1423,8 @@ namespace Nz::ShaderAst
statementPtr = RemapIndices(*statementPtr, remapCallbacks); statementPtr = RemapIndices(*statementPtr, remapCallbacks);
// Register exported variables (FIXME: This shouldn't be necessary and could be handled by the IndexRemapperVisitor) // Register exported variables (FIXME: This shouldn't be necessary and could be handled by the IndexRemapperVisitor)
m_context->importUsage = remappedExportedSet; //m_context->importUsage = remappedExportedSet;
CallOnExit restoreImportOnExit([&] { m_context->importUsage.reset(); }); //CallOnExit restoreImportOnExit([&] { m_context->importUsage.reset(); });
return AstCloner::Clone(*statementPtr); return AstCloner::Clone(*statementPtr);
} }
@ -1513,22 +1480,43 @@ namespace Nz::ShaderAst
auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier*
{ {
auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); return FindIdentifier(*m_context->currentEnv, identifierName);
if (it == m_context->identifiersInScope.rend())
return nullptr;
return &*it;
} }
template<typename F> template<typename F>
auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName, F&& functor) const -> const Identifier* auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName, F&& functor) const -> const Identifier*
{ {
auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) return FindIdentifier(*m_context->currentEnv, identifierName, std::forward<F>(functor));
}
auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName) const -> const Identifier*
{
auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });
if (it == environment.identifiersInScope.rend())
{
if (environment.parentEnv)
return FindIdentifier(*environment.parentEnv, identifierName);
else
return nullptr;
}
return &*it;
}
template<typename F>
auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const -> const Identifier*
{
auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier)
{ {
return identifier.name == identifierName && functor(identifier); return identifier.name == identifierName && functor(identifier);
}); });
if (it == m_context->identifiersInScope.rend()) if (it == environment.identifiersInScope.rend())
{
if (environment.parentEnv)
return FindIdentifier(*environment.parentEnv, identifierName, std::forward<F>(functor));
else
return nullptr; return nullptr;
}
return &*it; return &*it;
} }
@ -1542,7 +1530,7 @@ namespace Nz::ShaderAst
switch (identifier->type) switch (identifier->type)
{ {
case Identifier::Type::Constant: case Identifier::Type::Constant:
return m_context->constantValues.Retrieve(identifier->index); return m_context->currentEnv->constantValues.Retrieve(identifier->index);
case Identifier::Type::Struct: case Identifier::Type::Struct:
return StructType{ identifier->index }; return StructType{ identifier->index };
@ -1551,7 +1539,7 @@ namespace Nz::ShaderAst
return std::visit([&](auto&& arg) -> TypeParameter return std::visit([&](auto&& arg) -> TypeParameter
{ {
return arg; return arg;
}, m_context->types.Retrieve(identifier->index)); }, m_context->currentEnv->types.Retrieve(identifier->index));
case Identifier::Type::Alias: case Identifier::Type::Alias:
throw std::runtime_error("TODO"); throw std::runtime_error("TODO");
@ -1587,16 +1575,16 @@ namespace Nz::ShaderAst
void SanitizeVisitor::PushScope() void SanitizeVisitor::PushScope()
{ {
auto& scope = m_context->scopes.emplace_back(); auto& scope = m_context->currentEnv->scopes.emplace_back();
scope.previousSize = m_context->identifiersInScope.size(); scope.previousSize = m_context->currentEnv->identifiersInScope.size();
} }
void SanitizeVisitor::PopScope() void SanitizeVisitor::PopScope()
{ {
assert(!m_context->scopes.empty()); assert(!m_context->currentEnv->scopes.empty());
auto& scope = m_context->scopes.back(); auto& scope = m_context->currentEnv->scopes.back();
m_context->identifiersInScope.resize(scope.previousSize); m_context->currentEnv->identifiersInScope.resize(scope.previousSize);
m_context->scopes.pop_back(); m_context->currentEnv->scopes.pop_back();
} }
ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression) ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
@ -1664,7 +1652,7 @@ namespace Nz::ShaderAst
AstConstantPropagationVisitor::Options optimizerOptions; AstConstantPropagationVisitor::Options optimizerOptions;
optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue& optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue&
{ {
return m_context->constantValues.Retrieve(constantId); return m_context->currentEnv->constantValues.Retrieve(constantId);
}; };
// Run optimizer on constant value to hopefully retrieve a single constant value // Run optimizer on constant value to hopefully retrieve a single constant value
@ -1673,7 +1661,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen) void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
{ {
auto& funcData = m_context->functions.Retrieve(funcIndex); auto& funcData = m_context->currentEnv->functions.Retrieve(funcIndex);
funcData.flags |= flags; funcData.flags |= flags;
for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i)) for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i))
@ -1844,29 +1832,22 @@ namespace Nz::ShaderAst
RegisterIntrinsic("reflect", IntrinsicType::Reflect); RegisterIntrinsic("reflect", IntrinsicType::Reflect);
} }
std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value, bool hidden, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index)
{
std::size_t constantIndex = m_context->constantValues.Register(std::move(value), index);
if (!hidden)
{ {
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
m_context->identifiersInScope.push_back({ std::size_t constantIndex = m_context->currentEnv->constantValues.Register(std::move(value), index);
m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
constantIndex, constantIndex,
Identifier::Type::Constant Identifier::Type::Constant
}); });
}
return constantIndex; return constantIndex;
} }
std::size_t SanitizeVisitor::RegisterFunction(std::string name, FunctionData funcData, bool hidden, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index)
{
std::size_t functionIndex = m_context->functions.Register(std::move(funcData), index);
if (!hidden)
{ {
if (auto* identifier = FindIdentifier(name)) if (auto* identifier = FindIdentifier(name))
{ {
@ -1875,7 +1856,7 @@ namespace Nz::ShaderAst
// Functions cannot be declared twice, except for entry ones if their stages are different // Functions cannot be declared twice, except for entry ones if their stages are different
if (funcData.node->entryStage.HasValue() && identifier->type == Identifier::Type::Function) if (funcData.node->entryStage.HasValue() && identifier->type == Identifier::Type::Function)
{ {
auto& otherFunction = m_context->functions.Retrieve(identifier->index); auto& otherFunction = m_context->currentEnv->functions.Retrieve(identifier->index);
if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue()) if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue())
duplicate = false; duplicate = false;
} }
@ -1884,97 +1865,87 @@ namespace Nz::ShaderAst
throw AstError{ funcData.node->name + " is already used" }; throw AstError{ funcData.node->name + " is already used" };
} }
m_context->identifiersInScope.push_back({ std::size_t functionIndex = m_context->currentEnv->functions.Register(std::move(funcData), index);
m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
functionIndex, functionIndex,
Identifier::Type::Function Identifier::Type::Function
}); });
}
return functionIndex; return functionIndex;
} }
std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type, bool hidden, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterIntrinsic(std::string name, IntrinsicType type)
{
std::size_t intrinsicIndex = m_context->intrinsics.Register(std::move(type), index);
if (!hidden)
{ {
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
m_context->identifiersInScope.push_back({ std::size_t intrinsicIndex = m_context->currentEnv->intrinsics.Register(std::move(type));
m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
intrinsicIndex, intrinsicIndex,
Identifier::Type::Intrinsic Identifier::Type::Intrinsic
}); });
}
return intrinsicIndex; return intrinsicIndex;
} }
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, bool hidden, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex)
{ {
std::size_t structIndex = m_context->structs.Register(description, index); return std::size_t();
}
if (!hidden) std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index)
{ {
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
m_context->identifiersInScope.push_back({ std::size_t structIndex = m_context->currentEnv->structs.Register(description, index);
m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
structIndex, structIndex,
Identifier::Type::Struct Identifier::Type::Struct
}); });
}
return structIndex; return structIndex;
} }
std::size_t SanitizeVisitor::RegisterType(std::string name, ExpressionType expressionType, bool hidden, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index)
{
std::size_t typeIndex = m_context->types.Register(std::move(expressionType), index);
if (!hidden)
{ {
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
m_context->identifiersInScope.push_back({ std::size_t typeIndex = m_context->currentEnv->types.Register(std::move(expressionType), index);
m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
typeIndex, typeIndex,
Identifier::Type::Type Identifier::Type::Type
}); });
}
return typeIndex; return typeIndex;
} }
std::size_t SanitizeVisitor::RegisterType(std::string name, PartialType partialType, bool hidden, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index)
{
std::size_t typeIndex = m_context->types.Register(std::move(partialType), index);
if (!hidden)
{ {
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
m_context->identifiersInScope.push_back({ std::size_t typeIndex = m_context->currentEnv->types.Register(std::move(partialType), index);
m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
typeIndex, typeIndex,
Identifier::Type::Type Identifier::Type::Type
}); });
}
return typeIndex; return typeIndex;
} }
std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type, bool hidden, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index)
{
std::size_t varIndex = m_context->variableTypes.Register(std::move(type), index);
if (!hidden)
{ {
if (auto* identifier = FindIdentifier(name)) if (auto* identifier = FindIdentifier(name))
{ {
@ -1983,12 +1954,13 @@ namespace Nz::ShaderAst
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
} }
m_context->identifiersInScope.push_back({ std::size_t varIndex = m_context->currentEnv->variableTypes.Register(std::move(type), index);
m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
varIndex, varIndex,
Identifier::Type::Variable Identifier::Type::Variable
}); });
}
return varIndex; return varIndex;
} }
@ -2002,11 +1974,11 @@ namespace Nz::ShaderAst
for (auto& parameter : pendingFunc.cloneNode->parameters) for (auto& parameter : pendingFunc.cloneNode->parameters)
{ {
parameter.varIndex = RegisterVariable(parameter.name, parameter.type.GetResultingValue(), false, parameter.varIndex); parameter.varIndex = RegisterVariable(parameter.name, parameter.type.GetResultingValue(), parameter.varIndex);
SanitizeIdentifier(parameter.name); SanitizeIdentifier(parameter.name);
} }
Context::CurrentFunctionData tempFuncData; CurrentFunctionData tempFuncData;
if (pendingFunc.cloneNode->entryStage.HasValue()) if (pendingFunc.cloneNode->entryStage.HasValue())
tempFuncData.stageType = pendingFunc.cloneNode->entryStage.GetResultingValue(); tempFuncData.stageType = pendingFunc.cloneNode->entryStage.GetResultingValue();
@ -2025,21 +1997,22 @@ namespace Nz::ShaderAst
std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex; std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex;
for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i)) for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i))
{ {
auto& targetFunc = m_context->functions.Retrieve(i); auto& targetFunc = m_context->currentEnv->functions.Retrieve(i);
targetFunc.calledByFunctions.UnboundedSet(funcIndex); targetFunc.calledByFunctions.UnboundedSet(funcIndex);
} }
PopScope(); PopScope();
} }
m_context->pendingFunctions.clear();
Bitset<> seen; Bitset<> seen;
for (const auto& [funcIndex, funcData] : m_context->functions.values) for (const auto& [funcIndex, funcData] : m_context->currentEnv->functions.values)
{ {
PropagateFunctionFlags(funcIndex, funcData.flags, seen); PropagateFunctionFlags(funcIndex, funcData.flags, seen);
seen.Clear(); seen.Clear();
} }
for (const auto& [funcIndex, funcData] : m_context->functions.values) for (const auto& [funcIndex, funcData] : m_context->currentEnv->functions.values)
{ {
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment) if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
throw AstError{ "discard can only be used in the fragment stage" }; throw AstError{ "discard can only be used in the fragment stage" };
@ -2118,7 +2091,7 @@ namespace Nz::ShaderAst
std::size_t typeIndex = std::get<Type>(exprType).typeIndex; std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
const auto& type = m_context->types.Retrieve(typeIndex); const auto& type = m_context->currentEnv->types.Retrieve(typeIndex);
if (std::holds_alternative<PartialType>(type)) if (std::holds_alternative<PartialType>(type))
throw AstError{ "full type expected" }; throw AstError{ "full type expected" };
@ -2157,6 +2130,41 @@ namespace Nz::ShaderAst
} }
} }
MultiStatementPtr SanitizeVisitor::SanitizeInternal(MultiStatement& rootNode, std::string* error)
{
MultiStatementPtr output;
PushScope(); //< Global scope
{
RegisterBuiltin();
// First pass, evaluate everything except function code
try
{
output = static_unique_pointer_cast<MultiStatement>(AstCloner::Clone(rootNode));
}
catch (const AstError& err)
{
if (!error)
throw std::runtime_error(err.errMsg);
*error = err.errMsg;
}
catch (const std::runtime_error& err)
{
if (!error)
throw;
*error = err.what();
}
ResolveFunctions();
}
PopScope();
return output;
}
void SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const void SanitizeVisitor::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const
{ {
return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right)); return TypeMustMatch(GetExpressionType(*left), GetExpressionType(*right));
@ -2190,7 +2198,7 @@ namespace Nz::ShaderAst
if (IsTypeExpression(exprType)) if (IsTypeExpression(exprType))
{ {
std::size_t typeIndex = std::get<Type>(exprType).typeIndex; std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
const auto& type = m_context->types.Retrieve(typeIndex); const auto& type = m_context->currentEnv->types.Retrieve(typeIndex);
if (!std::holds_alternative<PartialType>(type)) if (!std::holds_alternative<PartialType>(type))
throw std::runtime_error("only partial types can be specialized"); throw std::runtime_error("only partial types can be specialized");
@ -2283,7 +2291,7 @@ namespace Nz::ShaderAst
Int32 index = std::get<Int32>(constantExpr.value); Int32 index = std::get<Int32>(constantExpr.value);
std::size_t structIndex = ResolveStruct(exprType); std::size_t structIndex = ResolveStruct(exprType);
const StructDescription* s = m_context->structs.Retrieve(structIndex); const StructDescription* s = m_context->currentEnv->structs.Retrieve(structIndex);
exprType = ResolveType(s->members[index].type); exprType = ResolveType(s->members[index].type);
} }
@ -2357,7 +2365,7 @@ namespace Nz::ShaderAst
assert(std::holds_alternative<FunctionType>(targetFuncType)); assert(std::holds_alternative<FunctionType>(targetFuncType));
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex; std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
auto& funcData = m_context->functions.Retrieve(targetFuncIndex); auto& funcData = m_context->currentEnv->functions.Retrieve(targetFuncIndex);
const DeclareFunctionStatement* referenceDeclaration = funcData.node; const DeclareFunctionStatement* referenceDeclaration = funcData.node;
@ -2468,7 +2476,7 @@ namespace Nz::ShaderAst
TypeMustMatch(resolvedType, GetExpressionType(*node.initialExpression)); TypeMustMatch(resolvedType, GetExpressionType(*node.initialExpression));
} }
node.varIndex = RegisterVariable(node.varName, resolvedType, false, node.varIndex); node.varIndex = RegisterVariable(node.varName, resolvedType, node.varIndex);
node.varType = std::move(resolvedType); node.varType = std::move(resolvedType);
if (m_context->options.makeVariableNameUnique) if (m_context->options.makeVariableNameUnique)
@ -2720,7 +2728,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(VariableExpression& node) void SanitizeVisitor::Validate(VariableExpression& node)
{ {
node.cachedExpressionType = m_context->variableTypes.Retrieve(node.variableId); node.cachedExpressionType = m_context->currentEnv->variableTypes.Retrieve(node.variableId);
} }
ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr) ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr)

View File

@ -34,6 +34,7 @@ namespace Nz::ShaderLang
{ "nzsl_version", ShaderAst::AttributeType::LangVersion }, { "nzsl_version", ShaderAst::AttributeType::LangVersion },
{ "set", ShaderAst::AttributeType::Set }, { "set", ShaderAst::AttributeType::Set },
{ "unroll", ShaderAst::AttributeType::Unroll }, { "unroll", ShaderAst::AttributeType::Unroll },
{ "uuid", ShaderAst::AttributeType::Uuid },
}; };
std::unordered_map<std::string, ShaderStageType> s_entryPoints = { std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
@ -247,10 +248,8 @@ namespace Nz::ShaderLang
{ {
Expect(Advance(), TokenType::Module); Expect(Advance(), TokenType::Module);
if (m_context->module)
throw DuplicateModule{ "you must set one module statement per file" };
std::optional<UInt32> moduleVersion; std::optional<UInt32> moduleVersion;
std::optional<Uuid> moduleId;
for (auto&& [attributeType, arg] : attributes) for (auto&& [attributeType, arg] : attributes)
{ {
@ -296,6 +295,32 @@ namespace Nz::ShaderLang
break; break;
} }
case ShaderAst::AttributeType::Uuid:
{
if (moduleId.has_value())
throw AttributeError{ "attribute" + std::string("uuid") + " can only be present once" };
if (!arg)
throw AttributeError{ "attribute " + std::string("uuid") + " requires a parameter" };
const ShaderAst::ExpressionPtr& expr = *arg;
if (expr->GetType() != ShaderAst::NodeType::ConstantValueExpression)
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
const std::string& uuidStr = std::get<std::string>(constantValue.value);
Uuid uuid = Uuid::FromString(uuidStr);
if (uuid.IsNull())
throw AttributeError{ "attribute " + std::string("uuid") + " value is not a valid UUID" };
moduleId = uuid;
break;
}
default: default:
throw AttributeError{ "unhandled attribute for module" }; throw AttributeError{ "unhandled attribute for module" };
} }
@ -304,9 +329,34 @@ namespace Nz::ShaderLang
if (!moduleVersion.has_value()) if (!moduleVersion.has_value())
throw AttributeError{ "missing module version" }; throw AttributeError{ "missing module version" };
m_context->module = std::make_shared<ShaderAst::Module>(*moduleVersion); if (!moduleId)
moduleId = Uuid::Generate();
auto module = std::make_shared<ShaderAst::Module>(*moduleVersion, *moduleId);
if (Peek().type == TokenType::Identifier)
{
// Imported module
if (!m_context->module)
throw UnexpectedToken{}; //< "unexpected token before module declaration"
const std::string& identifier = std::get<std::string>(Peek().data);
Consume();
module->rootNode->statements = ParseStatementList();
auto& importedModule = m_context->module->importedModules.emplace_back();
importedModule.module = std::move(module);
}
else
{
Expect(Advance(), TokenType::Semicolon); Expect(Advance(), TokenType::Semicolon);
if (m_context->module)
throw DuplicateModule{ "you must set one module statement per file" };
m_context->module = std::move(module);
}
} }
void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type, ShaderAst::ExpressionPtr& initialValue) void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type, ShaderAst::ExpressionPtr& initialValue)

View File

@ -13,6 +13,8 @@ SCENARIO("Uuid", "[CORE][UUID]")
// Testing some invalid cases // Testing some invalid cases
CHECK(Nz::Uuid::FromString("Nazara Engine") == Nz::Uuid()); CHECK(Nz::Uuid::FromString("Nazara Engine") == Nz::Uuid());
CHECK(Nz::Uuid::FromString("1b0e29af_fd4a_43e0_ba4c_e9334183b1f1") == Nz::Uuid());
CHECK(Nz::Uuid::FromString("1b0e29af-fd4a_43e0_ba4c_e93341-3b1f1") == Nz::Uuid());
CHECK(Nz::Uuid::FromString("Zb0e29af-fd4a-43e0-ba4c-e9334183b1f1") == Nz::Uuid()); CHECK(Nz::Uuid::FromString("Zb0e29af-fd4a-43e0-ba4c-e9334183b1f1") == Nz::Uuid());
CHECK(Nz::Uuid::FromString("1b0e29a\t-fd4a-43e0-ba4c-e9334183b1f1") == Nz::Uuid()); CHECK(Nz::Uuid::FromString("1b0e29a\t-fd4a-43e0-ba4c-e9334183b1f1") == Nz::Uuid());
CHECK(Nz::Uuid::FromString("1b0e29af-fd4\v-43e0-ba4c-e9334183b1f1") == Nz::Uuid()); CHECK(Nz::Uuid::FromString("1b0e29af-fd4\v-43e0-ba4c-e9334183b1f1") == Nz::Uuid());