This commit is contained in:
Lynix 2022-03-08 01:30:48 +01:00 committed by Jérôme Leclercq
parent 012712b8d0
commit 83d26e209e
22 changed files with 295 additions and 142 deletions

View File

@ -58,6 +58,7 @@ namespace Nz::ShaderAst
virtual StatementPtr Clone(BranchStatement& node); virtual StatementPtr Clone(BranchStatement& node);
virtual StatementPtr Clone(ConditionalStatement& node); virtual StatementPtr Clone(ConditionalStatement& node);
virtual StatementPtr Clone(DeclareAliasStatement& node);
virtual StatementPtr Clone(DeclareConstStatement& node); virtual StatementPtr Clone(DeclareConstStatement& node);
virtual StatementPtr Clone(DeclareExternalStatement& node); virtual StatementPtr Clone(DeclareExternalStatement& node);
virtual StatementPtr Clone(DeclareFunctionStatement& node); virtual StatementPtr Clone(DeclareFunctionStatement& node);

View File

@ -50,6 +50,7 @@ namespace Nz::ShaderAst
inline bool Compare(const BranchStatement& lhs, const BranchStatement& rhs); inline bool Compare(const BranchStatement& lhs, const BranchStatement& rhs);
inline bool Compare(const ConditionalStatement& lhs, const ConditionalStatement& rhs); inline bool Compare(const ConditionalStatement& lhs, const ConditionalStatement& rhs);
inline bool Compare(const DeclareAliasStatement& lhs, const DeclareAliasStatement& rhs);
inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs); inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs);
inline bool Compare(const DeclareExternalStatement& lhs, const DeclareExternalStatement& rhs); inline bool Compare(const DeclareExternalStatement& lhs, const DeclareExternalStatement& rhs);
inline bool Compare(const DeclareFunctionStatement& lhs, const DeclareFunctionStatement& rhs); inline bool Compare(const DeclareFunctionStatement& lhs, const DeclareFunctionStatement& rhs);

View File

@ -419,6 +419,17 @@ namespace Nz::ShaderAst
return true; return true;
} }
bool Compare(const DeclareAliasStatement& lhs, const DeclareAliasStatement& rhs)
{
if (!Compare(lhs.name, rhs.name))
return false;
if (!Compare(lhs.expression, rhs.expression))
return false;
return true;
}
inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs) inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs)
{ {
if (!Compare(lhs.name, rhs.name)) if (!Compare(lhs.name, rhs.name))

View File

@ -45,6 +45,7 @@ NAZARA_SHADERAST_EXPRESSION(VariableExpression)
NAZARA_SHADERAST_EXPRESSION(UnaryExpression) NAZARA_SHADERAST_EXPRESSION(UnaryExpression)
NAZARA_SHADERAST_STATEMENT(BranchStatement) NAZARA_SHADERAST_STATEMENT(BranchStatement)
NAZARA_SHADERAST_STATEMENT(ConditionalStatement) NAZARA_SHADERAST_STATEMENT(ConditionalStatement)
NAZARA_SHADERAST_STATEMENT(DeclareAliasStatement)
NAZARA_SHADERAST_STATEMENT(DeclareConstStatement) NAZARA_SHADERAST_STATEMENT(DeclareConstStatement)
NAZARA_SHADERAST_STATEMENT(DeclareExternalStatement) NAZARA_SHADERAST_STATEMENT(DeclareExternalStatement)
NAZARA_SHADERAST_STATEMENT(DeclareFunctionStatement) NAZARA_SHADERAST_STATEMENT(DeclareFunctionStatement)

View File

@ -38,6 +38,7 @@ namespace Nz::ShaderAst
void Visit(BranchStatement& node) override; void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override; void Visit(ConditionalStatement& node) override;
void Visit(DeclareAliasStatement& node) override;
void Visit(DeclareConstStatement& node) override; void Visit(DeclareConstStatement& node) override;
void Visit(DeclareExternalStatement& node) override; void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override; void Visit(DeclareFunctionStatement& node) override;

View File

@ -41,6 +41,7 @@ namespace Nz::ShaderAst
void Serialize(BranchStatement& node); void Serialize(BranchStatement& node);
void Serialize(ConditionalStatement& node); void Serialize(ConditionalStatement& node);
void Serialize(DeclareAliasStatement& node);
void Serialize(DeclareConstStatement& node); void Serialize(DeclareConstStatement& node);
void Serialize(DeclareExternalStatement& node); void Serialize(DeclareExternalStatement& node);
void Serialize(DeclareFunctionStatement& node); void Serialize(DeclareFunctionStatement& node);

View File

@ -244,6 +244,16 @@ namespace Nz::ShaderAst
StatementPtr statement; StatementPtr statement;
}; };
struct NAZARA_SHADER_API DeclareAliasStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> aliasIndex;
std::string name;
ExpressionPtr expression;
};
struct NAZARA_SHADER_API DeclareConstStatement : Statement struct NAZARA_SHADER_API DeclareConstStatement : Statement
{ {
NodeType GetType() const override; NodeType GetType() const override;

View File

@ -57,11 +57,13 @@ namespace Nz::ShaderAst
}; };
private: private:
enum class IdentifierCategory;
struct CurrentFunctionData; struct CurrentFunctionData;
struct Environment; struct Environment;
struct FunctionData; struct FunctionData;
struct Identifier; struct Identifier;
template<typename T> struct IdentifierData; struct IdentifierData;
template<typename T> struct IdentifierList;
struct Scope; struct Scope;
using AstCloner::CloneExpression; using AstCloner::CloneExpression;
@ -84,6 +86,7 @@ namespace Nz::ShaderAst
StatementPtr Clone(BranchStatement& node) override; StatementPtr Clone(BranchStatement& node) override;
StatementPtr Clone(ConditionalStatement& node) override; StatementPtr Clone(ConditionalStatement& node) override;
StatementPtr Clone(DeclareAliasStatement& node) override;
StatementPtr Clone(DeclareConstStatement& node) override; StatementPtr Clone(DeclareConstStatement& node) override;
StatementPtr Clone(DeclareExternalStatement& node) override; StatementPtr Clone(DeclareExternalStatement& node) override;
StatementPtr Clone(DeclareFunctionStatement& node) override; StatementPtr Clone(DeclareFunctionStatement& node) override;
@ -99,10 +102,10 @@ namespace Nz::ShaderAst
StatementPtr Clone(ScopedStatement& node) override; StatementPtr Clone(ScopedStatement& node) override;
StatementPtr Clone(WhileStatement& node) override; StatementPtr Clone(WhileStatement& node) override;
const Identifier* FindIdentifier(const std::string_view& identifierName) const; const IdentifierData* FindIdentifier(const std::string_view& identifierName) const;
template<typename F> const Identifier* FindIdentifier(const std::string_view& identifierName, F&& functor) const; template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
const Identifier* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const; const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
template<typename F> const Identifier* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const; template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
TypeParameter FindTypeParameter(const std::string_view& identifierName) const; TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
Expression& MandatoryExpr(const ExpressionPtr& node) const; Expression& MandatoryExpr(const ExpressionPtr& node) const;
@ -120,6 +123,8 @@ namespace Nz::ShaderAst
void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen); void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen);
void RegisterBuiltin(); void RegisterBuiltin();
std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional<std::size_t> index = {});
std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {}); std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {});
std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {}); std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {});
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
@ -129,6 +134,7 @@ namespace Nz::ShaderAst
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {}); std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {}); std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
const IdentifierData* ResolveAlias(const IdentifierData* identifier) const;
void ResolveFunctions(); void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node); const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const ExpressionType& exprType);
@ -146,6 +152,7 @@ namespace Nz::ShaderAst
StatementPtr Unscope(StatementPtr node); StatementPtr Unscope(StatementPtr node);
void Validate(DeclareAliasStatement& node);
void Validate(WhileStatement& node); void Validate(WhileStatement& node);
void Validate(AccessIndexExpression& node); void Validate(AccessIndexExpression& node);
@ -160,6 +167,18 @@ namespace Nz::ShaderAst
void Validate(VariableExpression& node); void Validate(VariableExpression& node);
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr); ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr);
enum class IdentifierCategory
{
Alias,
Constant,
Function,
Intrinsic,
Module,
Struct,
Type,
Variable
};
struct FunctionData struct FunctionData
{ {
Bitset<> calledByFunctions; Bitset<> calledByFunctions;
@ -167,23 +186,16 @@ namespace Nz::ShaderAst
FunctionFlags flags; FunctionFlags flags;
}; };
struct IdentifierData
{
std::size_t index;
IdentifierCategory category;
};
struct Identifier struct Identifier
{ {
enum class Type
{
Alias,
Constant,
Function,
Intrinsic,
Module,
Struct,
Type,
Variable
};
std::string name; std::string name;
std::size_t index; IdentifierData data;
Type type;
}; };
struct Context; struct Context;

View File

@ -105,6 +105,7 @@ namespace Nz
void Visit(ShaderAst::UnaryExpression& node) override; void Visit(ShaderAst::UnaryExpression& node) override;
void Visit(ShaderAst::BranchStatement& node) override; void Visit(ShaderAst::BranchStatement& node) override;
void Visit(ShaderAst::DeclareAliasStatement& node) override;
void Visit(ShaderAst::DeclareConstStatement& node) override; void Visit(ShaderAst::DeclareConstStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override;

View File

@ -110,6 +110,7 @@ namespace Nz
void Visit(ShaderAst::BranchStatement& node) override; void Visit(ShaderAst::BranchStatement& node) override;
void Visit(ShaderAst::ConditionalStatement& node) override; void Visit(ShaderAst::ConditionalStatement& node) override;
void Visit(ShaderAst::DeclareAliasStatement& node) override;
void Visit(ShaderAst::DeclareConstStatement& node) override; void Visit(ShaderAst::DeclareConstStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override;

View File

@ -76,6 +76,11 @@ namespace Nz::ShaderBuilder
template<typename T> std::unique_ptr<ShaderAst::ConstantValueExpression> operator()(ShaderAst::ExpressionType type, T value) const; template<typename T> std::unique_ptr<ShaderAst::ConstantValueExpression> operator()(ShaderAst::ExpressionType type, T value) const;
}; };
struct DeclareAlias
{
inline std::unique_ptr<ShaderAst::DeclareAliasStatement> operator()(std::string name, ShaderAst::ExpressionPtr expression) const;
};
struct DeclareConst struct DeclareConst
{ {
inline std::unique_ptr<ShaderAst::DeclareConstStatement> operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const; inline std::unique_ptr<ShaderAst::DeclareConstStatement> operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const;
@ -191,6 +196,7 @@ namespace Nz::ShaderBuilder
constexpr Impl::ConditionalStatement ConditionalStatement; constexpr Impl::ConditionalStatement ConditionalStatement;
constexpr Impl::Constant Constant; constexpr Impl::Constant Constant;
constexpr Impl::Branch<true> ConstBranch; constexpr Impl::Branch<true> ConstBranch;
constexpr Impl::DeclareAlias DeclareAlias;
constexpr Impl::DeclareConst DeclareConst; constexpr Impl::DeclareConst DeclareConst;
constexpr Impl::DeclareFunction DeclareFunction; constexpr Impl::DeclareFunction DeclareFunction;
constexpr Impl::DeclareOption DeclareOption; constexpr Impl::DeclareOption DeclareOption;

View File

@ -195,6 +195,15 @@ namespace Nz::ShaderBuilder
throw std::runtime_error("unexpected primitive type"); throw std::runtime_error("unexpected primitive type");
} }
inline std::unique_ptr<ShaderAst::DeclareAliasStatement> Impl::DeclareAlias::operator()(std::string name, ShaderAst::ExpressionPtr expression) const
{
auto declareAliasNode = std::make_unique<ShaderAst::DeclareAliasStatement>();
declareAliasNode->name = std::move(name);
declareAliasNode->expression = std::move(expression);
return declareAliasNode;
}
inline std::unique_ptr<ShaderAst::DeclareConstStatement> Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const inline std::unique_ptr<ShaderAst::DeclareConstStatement> Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const
{ {
auto declareConstNode = std::make_unique<ShaderAst::DeclareConstStatement>(); auto declareConstNode = std::make_unique<ShaderAst::DeclareConstStatement>();

View File

@ -48,6 +48,7 @@ namespace Nz
void Visit(ShaderAst::CallFunctionExpression& node) override; void Visit(ShaderAst::CallFunctionExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConstantValueExpression& node) override; void Visit(ShaderAst::ConstantValueExpression& node) override;
void Visit(ShaderAst::DeclareAliasStatement& node) override;
void Visit(ShaderAst::DeclareConstStatement& node) override; void Visit(ShaderAst::DeclareConstStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override;

View File

@ -77,6 +77,16 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
StatementPtr AstCloner::Clone(DeclareAliasStatement& node)
{
auto clone = std::make_unique<DeclareAliasStatement>();
clone->aliasIndex = node.aliasIndex;
clone->name = node.name;
clone->expression = CloneExpression(node.expression);
return clone;
}
StatementPtr AstCloner::Clone(DeclareConstStatement& node) StatementPtr AstCloner::Clone(DeclareConstStatement& node)
{ {
auto clone = std::make_unique<DeclareConstStatement>(); auto clone = std::make_unique<DeclareConstStatement>();

View File

@ -117,6 +117,12 @@ namespace Nz::ShaderAst
node.statement->Visit(*this); node.statement->Visit(*this);
} }
void AstRecursiveVisitor::Visit(DeclareAliasStatement& node)
{
if (node.expression)
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareConstStatement& node) void AstRecursiveVisitor::Visit(DeclareConstStatement& node)
{ {
if (node.expression) if (node.expression)

View File

@ -193,6 +193,13 @@ namespace Nz::ShaderAst
Node(node.statement); Node(node.statement);
} }
void AstSerializerBase::Serialize(DeclareAliasStatement& node)
{
OptVal(node.aliasIndex);
Value(node.name);
Node(node.expression);
}
void AstSerializerBase::Serialize(DeclareExternalStatement& node) void AstSerializerBase::Serialize(DeclareExternalStatement& node)
{ {
ExprValue(node.bindingSet); ExprValue(node.bindingSet);

View File

@ -46,7 +46,7 @@ namespace Nz::ShaderAst
}; };
template<typename T> template<typename T>
struct SanitizeVisitor::IdentifierData struct SanitizeVisitor::IdentifierList
{ {
Bitset<UInt64> availableIndices; Bitset<UInt64> availableIndices;
Bitset<UInt64> preregisteredIndices; Bitset<UInt64> preregisteredIndices;
@ -117,13 +117,6 @@ namespace Nz::ShaderAst
std::shared_ptr<Environment> parentEnv; std::shared_ptr<Environment> parentEnv;
std::vector<Identifier> identifiersInScope; std::vector<Identifier> identifiersInScope;
std::vector<Scope> scopes; std::vector<Scope> scopes;
IdentifierData<ConstantValue> constantValues;
IdentifierData<FunctionData> functions;
IdentifierData<IntrinsicType> intrinsics;
IdentifierData<std::size_t> moduleIndices;
IdentifierData<StructDescription*> structs;
IdentifierData<std::variant<ExpressionType, PartialType>> types;
IdentifierData<ExpressionType> variableTypes;
}; };
struct SanitizeVisitor::Context struct SanitizeVisitor::Context
@ -143,6 +136,14 @@ namespace Nz::ShaderAst
std::unordered_set<UInt64> usedBindingIndexes; std::unordered_set<UInt64> usedBindingIndexes;
std::shared_ptr<Environment> globalEnv; std::shared_ptr<Environment> globalEnv;
std::shared_ptr<Environment> currentEnv; std::shared_ptr<Environment> currentEnv;
IdentifierList<ConstantValue> constantValues;
IdentifierList<FunctionData> functions;
IdentifierList<IdentifierData> aliases;
IdentifierList<IntrinsicType> intrinsics;
IdentifierList<std::size_t> moduleIndices;
IdentifierList<StructDescription*> structs;
IdentifierList<std::variant<ExpressionType, PartialType>> types;
IdentifierList<ExpressionType> variableTypes;
Options options; Options options;
CurrentFunctionData* currentFunction = nullptr; CurrentFunctionData* currentFunction = nullptr;
}; };
@ -241,7 +242,7 @@ namespace Nz::ShaderAst
else if (IsStructType(exprType)) else if (IsStructType(exprType))
{ {
std::size_t structIndex = ResolveStruct(exprType); std::size_t structIndex = ResolveStruct(exprType);
const StructDescription* s = m_context->currentEnv->structs.Retrieve(structIndex); const StructDescription* s = m_context->structs.Retrieve(structIndex);
// Retrieve member index (not counting disabled fields) // Retrieve member index (not counting disabled fields)
Int32 fieldIndex = 0; Int32 fieldIndex = 0;
@ -569,7 +570,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node)
{ {
// Replace by constant value // Replace by constant value
auto constant = ShaderBuilder::Constant(m_context->currentEnv->constantValues.Retrieve(node.constantId)); auto constant = ShaderBuilder::Constant(m_context->constantValues.Retrieve(node.constantId));
constant->cachedExpressionType = GetExpressionType(constant->value); constant->cachedExpressionType = GetExpressionType(constant->value);
return constant; return constant;
@ -579,32 +580,32 @@ namespace Nz::ShaderAst
{ {
assert(m_context); assert(m_context);
const Identifier* identifier = FindIdentifier(node.identifier); const IdentifierData* identifierData = FindIdentifier(node.identifier);
if (!identifier) if (!identifierData)
throw AstError{ "unknown identifier " + node.identifier }; throw AstError{ "unknown identifier " + node.identifier };
switch (identifier->type) switch (identifierData->category)
{ {
case Identifier::Type::Constant: case IdentifierCategory::Constant:
{ {
// Replace IdentifierExpression by Constant(Value)Expression // Replace IdentifierExpression by Constant(Value)Expression
ConstantExpression constantExpr; ConstantExpression constantExpr;
constantExpr.constantId = identifier->index; constantExpr.constantId = identifierData->index;
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
} }
case Identifier::Type::Function: case IdentifierCategory::Function:
{ {
auto clone = AstCloner::Clone(node); auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = FunctionType{ identifier->index }; clone->cachedExpressionType = FunctionType{ identifierData->index };
return clone; return clone;
} }
case Identifier::Type::Intrinsic: case IdentifierCategory::Intrinsic:
{ {
IntrinsicType intrinsicType = m_context->currentEnv->intrinsics.Retrieve(identifier->index); IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index);
auto clone = AstCloner::Clone(node); auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType };
@ -612,28 +613,28 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
case Identifier::Type::Struct: case IdentifierCategory::Struct:
{ {
auto clone = AstCloner::Clone(node); auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = StructType{ identifier->index }; clone->cachedExpressionType = StructType{ identifierData->index };
return clone; return clone;
} }
case Identifier::Type::Type: case IdentifierCategory::Type:
{ {
auto clone = AstCloner::Clone(node); auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = Type{ identifier->index }; clone->cachedExpressionType = Type{ identifierData->index };
return clone; return clone;
} }
case Identifier::Type::Variable: case IdentifierCategory::Variable:
{ {
// Replace IdentifierExpression by VariableExpression // Replace IdentifierExpression by VariableExpression
auto varExpr = std::make_unique<VariableExpression>(); auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = m_context->currentEnv->variableTypes.Retrieve(identifier->index); varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index);
varExpr->variableId = identifier->index; varExpr->variableId = identifierData->index;
return varExpr; return varExpr;
} }
@ -763,6 +764,14 @@ namespace Nz::ShaderAst
return ShaderBuilder::NoOp(); return ShaderBuilder::NoOp();
} }
StatementPtr SanitizeVisitor::Clone(DeclareAliasStatement& node)
{
auto clone = static_unique_pointer_cast<DeclareAliasStatement>(AstCloner::Clone(node));
Validate(*clone);
return clone;
}
StatementPtr SanitizeVisitor::Clone(DeclareConstStatement& node) StatementPtr SanitizeVisitor::Clone(DeclareConstStatement& node)
{ {
auto clone = static_unique_pointer_cast<DeclareConstStatement>(AstCloner::Clone(node)); auto clone = static_unique_pointer_cast<DeclareConstStatement>(AstCloner::Clone(node));
@ -987,7 +996,7 @@ namespace Nz::ShaderAst
else if (IsStructType(resolvedType)) else if (IsStructType(resolvedType))
{ {
std::size_t structIndex = std::get<StructType>(resolvedType).structIndex; std::size_t structIndex = std::get<StructType>(resolvedType).structIndex;
const StructDescription* desc = m_context->currentEnv->structs.Retrieve(structIndex); const StructDescription* desc = m_context->structs.Retrieve(structIndex);
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue()) if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
throw AstError{ "inner struct layout mismatch" }; throw AstError{ "inner struct layout mismatch" };
} }
@ -1353,8 +1362,7 @@ namespace Nz::ShaderAst
targetModule->rootNode->sectionName = "Module " + targetModule->metadata->moduleId.ToString(); targetModule->rootNode->sectionName = "Module " + targetModule->metadata->moduleId.ToString();
m_context->currentEnv = m_context->moduleEnvironments.emplace_back(); m_context->currentEnv = m_context->moduleEnvironments.emplace_back(std::make_shared<Environment>());
m_context->currentEnv->parentEnv = m_context->globalEnv;
CallOnExit restoreEnvOnExit([&] { m_context->currentEnv = m_context->globalEnv; }); CallOnExit restoreEnvOnExit([&] { m_context->currentEnv = m_context->globalEnv; });
ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata); ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata);
@ -1373,6 +1381,8 @@ namespace Nz::ShaderAst
DependencyCheckerVisitor::UsageSet exportedSet; DependencyCheckerVisitor::UsageSet exportedSet;
MultiStatementPtr aliasBlock = std::make_unique<MultiStatement>();
AstExportVisitor::Callbacks callbacks; AstExportVisitor::Callbacks callbacks;
callbacks.onExportedStruct = [&](DeclareStructStatement& node) callbacks.onExportedStruct = [&](DeclareStructStatement& node)
{ {
@ -1380,6 +1390,12 @@ namespace Nz::ShaderAst
moduleDependencies.MarkStructAsUsed(*node.structIndex); moduleDependencies.MarkStructAsUsed(*node.structIndex);
exportedSet.usedStructs.UnboundedSet(*node.structIndex); exportedSet.usedStructs.UnboundedSet(*node.structIndex);
auto alias = Clone(node);
// TODO: DeclareAlias
aliasBlock->statements.emplace_back(std::move(alias));
}; };
AstExportVisitor exportVisitor; AstExportVisitor exportVisitor;
@ -1387,46 +1403,13 @@ namespace Nz::ShaderAst
moduleDependencies.Resolve(); moduleDependencies.Resolve();
auto statementPtr = EliminateUnusedPass(*sanitizedModule->rootNode, moduleDependencies.GetUsage()); //m_context->
DependencyCheckerVisitor::UsageSet remappedExportedSet;
IndexRemapperVisitor::Callbacks remapCallbacks;
remapCallbacks.constIndexGenerator = [this](std::size_t previousIndex) { return m_context->currentEnv->constantValues.RegisterNewIndex(true); };
remapCallbacks.funcIndexGenerator = [&](std::size_t previousIndex)
{
std::size_t newIndex = m_context->currentEnv->functions.RegisterNewIndex(true);
if (exportedSet.usedFunctions.Test(previousIndex))
remappedExportedSet.usedFunctions.UnboundedSet(newIndex);
return newIndex;
};
remapCallbacks.structIndexGenerator = [&](std::size_t previousIndex)
{
std::size_t newIndex = m_context->currentEnv->structs.RegisterNewIndex(true);
if (exportedSet.usedStructs.Test(previousIndex))
remappedExportedSet.usedStructs.UnboundedSet(newIndex);
return newIndex;
};
remapCallbacks.varIndexGenerator = [&](std::size_t previousIndex)
{
std::size_t newIndex = m_context->currentEnv->variableTypes.RegisterNewIndex(true);
if (exportedSet.usedVariables.Test(previousIndex))
remappedExportedSet.usedVariables.UnboundedSet(newIndex);
return newIndex;
};
statementPtr = RemapIndices(*statementPtr, remapCallbacks);
// Register exported variables (FIXME: This shouldn't be necessary and could be handled by the IndexRemapperVisitor) // Register exported variables (FIXME: This shouldn't be necessary and could be handled by the IndexRemapperVisitor)
//m_context->importUsage = remappedExportedSet; //m_context->importUsage = remappedExportedSet;
//CallOnExit restoreImportOnExit([&] { m_context->importUsage.reset(); }); //CallOnExit restoreImportOnExit([&] { m_context->importUsage.reset(); });
return AstCloner::Clone(*statementPtr); return aliasBlock;
} }
StatementPtr SanitizeVisitor::Clone(MultiStatement& node) StatementPtr SanitizeVisitor::Clone(MultiStatement& node)
@ -1478,18 +1461,18 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const IdentifierData*
{ {
return FindIdentifier(*m_context->currentEnv, identifierName); return FindIdentifier(*m_context->currentEnv, identifierName);
} }
template<typename F> template<typename F>
auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName, F&& functor) const -> const Identifier* auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName, F&& functor) const -> const IdentifierData*
{ {
return FindIdentifier(*m_context->currentEnv, identifierName, std::forward<F>(functor)); return FindIdentifier(*m_context->currentEnv, identifierName, std::forward<F>(functor));
} }
auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName) const -> const Identifier* auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName) const -> const IdentifierData*
{ {
auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; });
if (it == environment.identifiersInScope.rend()) if (it == environment.identifiersInScope.rend())
@ -1500,15 +1483,21 @@ namespace Nz::ShaderAst
return nullptr; return nullptr;
} }
return &*it; return ResolveAlias(&it->data);
} }
template<typename F> template<typename F>
auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const -> const Identifier* auto SanitizeVisitor::FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const -> const IdentifierData*
{ {
auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier) auto it = std::find_if(environment.identifiersInScope.rbegin(), environment.identifiersInScope.rend(), [&](const Identifier& identifier)
{ {
return identifier.name == identifierName && functor(identifier); if (identifier.name == identifierName)
{
if (functor(*ResolveAlias(&identifier.data)))
return true;
}
return false;
}); });
if (it == environment.identifiersInScope.rend()) if (it == environment.identifiersInScope.rend())
{ {
@ -1518,7 +1507,7 @@ namespace Nz::ShaderAst
return nullptr; return nullptr;
} }
return &*it; return ResolveAlias(&it->data);
} }
TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const TypeParameter SanitizeVisitor::FindTypeParameter(const std::string_view& identifierName) const
@ -1527,30 +1516,43 @@ namespace Nz::ShaderAst
if (!identifier) if (!identifier)
throw std::runtime_error("identifier " + std::string(identifierName) + " not found"); throw std::runtime_error("identifier " + std::string(identifierName) + " not found");
switch (identifier->type) switch (identifier->category)
{ {
case Identifier::Type::Constant: case IdentifierCategory::Constant:
return m_context->currentEnv->constantValues.Retrieve(identifier->index); return m_context->constantValues.Retrieve(identifier->index);
case Identifier::Type::Struct: case IdentifierCategory::Struct:
return StructType{ identifier->index }; return StructType{ identifier->index };
case Identifier::Type::Type: case IdentifierCategory::Type:
return std::visit([&](auto&& arg) -> TypeParameter return std::visit([&](auto&& arg) -> TypeParameter
{ {
return arg; return arg;
}, m_context->currentEnv->types.Retrieve(identifier->index)); }, m_context->types.Retrieve(identifier->index));
case Identifier::Type::Alias: case IdentifierCategory::Alias:
throw std::runtime_error("TODO"); {
IdentifierCategory category;
std::size_t index;
do
{
const auto& aliasData = m_context->aliases.Retrieve(identifier->index);
category = aliasData.category;
index = aliasData.index;
}
while (category == IdentifierCategory::Alias);
}
case Identifier::Type::Function: case IdentifierCategory::Function:
throw std::runtime_error("unexpected function identifier"); throw std::runtime_error("unexpected function identifier");
case Identifier::Type::Intrinsic: case IdentifierCategory::Intrinsic:
throw std::runtime_error("unexpected intrinsic identifier"); throw std::runtime_error("unexpected intrinsic identifier");
case Identifier::Type::Variable: case IdentifierCategory::Module:
throw std::runtime_error("unexpected module identifier");
case IdentifierCategory::Variable:
throw std::runtime_error("unexpected variable identifier"); throw std::runtime_error("unexpected variable identifier");
} }
@ -1589,7 +1591,7 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression) ExpressionPtr SanitizeVisitor::CacheResult(ExpressionPtr expression)
{ {
// No need to cache LValues (variables/constants) (TODO: Improve this, as constants doens't need to be cached as well) // No need to cache LValues (variables/constants) (TODO: Improve this, as constants doesn't need to be cached as well)
if (GetExpressionCategory(*expression) == ExpressionCategory::LValue) if (GetExpressionCategory(*expression) == ExpressionCategory::LValue)
return expression; return expression;
@ -1652,7 +1654,7 @@ namespace Nz::ShaderAst
AstConstantPropagationVisitor::Options optimizerOptions; AstConstantPropagationVisitor::Options optimizerOptions;
optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue& optimizerOptions.constantQueryCallback = [this](std::size_t constantId) -> const ConstantValue&
{ {
return m_context->currentEnv->constantValues.Retrieve(constantId); return m_context->constantValues.Retrieve(constantId);
}; };
// Run optimizer on constant value to hopefully retrieve a single constant value // Run optimizer on constant value to hopefully retrieve a single constant value
@ -1661,7 +1663,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen) void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen)
{ {
auto& funcData = m_context->currentEnv->functions.Retrieve(funcIndex); auto& funcData = m_context->functions.Retrieve(funcIndex);
funcData.flags |= flags; funcData.flags |= flags;
for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i)) for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i))
@ -1832,16 +1834,31 @@ namespace Nz::ShaderAst
RegisterIntrinsic("reflect", IntrinsicType::Reflect); RegisterIntrinsic("reflect", IntrinsicType::Reflect);
} }
std::size_t SanitizeVisitor::RegisterAlias(std::string name, IdentifierData aliasData, std::optional<std::size_t> index)
{
if (FindIdentifier(name))
throw AstError{ name + " is already used" };
std::size_t aliasIndex = m_context->aliases.Register(std::move(aliasData), index);
m_context->currentEnv->identifiersInScope.push_back({
std::move(name),
aliasIndex,
IdentifierCategory::Alias
});
return aliasIndex;
}
std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index)
{ {
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
std::size_t constantIndex = m_context->currentEnv->constantValues.Register(std::move(value), index); std::size_t constantIndex = m_context->constantValues.Register(std::move(value), index);
m_context->currentEnv->identifiersInScope.push_back({ m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
constantIndex, constantIndex,
Identifier::Type::Constant IdentifierCategory::Constant
}); });
return constantIndex; return constantIndex;
@ -1854,23 +1871,23 @@ namespace Nz::ShaderAst
bool duplicate = true; bool duplicate = true;
// Functions cannot be declared twice, except for entry ones if their stages are different // Functions cannot be declared twice, except for entry ones if their stages are different
if (funcData.node->entryStage.HasValue() && identifier->type == Identifier::Type::Function) if (funcData.node->entryStage.HasValue() && identifier->category == IdentifierCategory::Function)
{ {
auto& otherFunction = m_context->currentEnv->functions.Retrieve(identifier->index); auto& otherFunction = m_context->functions.Retrieve(identifier->index);
if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue()) if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue())
duplicate = false; duplicate = false;
} }
if (duplicate) if (duplicate)
throw AstError{ funcData.node->name + " is already used" }; throw AstError{ name + " is already used" };
} }
std::size_t functionIndex = m_context->currentEnv->functions.Register(std::move(funcData), index); std::size_t functionIndex = m_context->functions.Register(std::move(funcData), index);
m_context->currentEnv->identifiersInScope.push_back({ m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
functionIndex, functionIndex,
Identifier::Type::Function IdentifierCategory::Function
}); });
return functionIndex; return functionIndex;
@ -1881,12 +1898,12 @@ namespace Nz::ShaderAst
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
std::size_t intrinsicIndex = m_context->currentEnv->intrinsics.Register(std::move(type)); std::size_t intrinsicIndex = m_context->intrinsics.Register(std::move(type));
m_context->currentEnv->identifiersInScope.push_back({ m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
intrinsicIndex, intrinsicIndex,
Identifier::Type::Intrinsic IdentifierCategory::Intrinsic
}); });
return intrinsicIndex; return intrinsicIndex;
@ -1894,7 +1911,18 @@ namespace Nz::ShaderAst
std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex) std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex)
{ {
return std::size_t(); if (FindIdentifier(moduleIdentifier))
throw AstError{ moduleIdentifier + " is already used" };
std::size_t intrinsicIndex = m_context->moduleIndices.Register(moduleIndex);
m_context->currentEnv->identifiersInScope.push_back({
std::move(moduleIdentifier),
intrinsicIndex,
IdentifierCategory::Module
});
return intrinsicIndex;
} }
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index)
@ -1902,12 +1930,12 @@ namespace Nz::ShaderAst
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
std::size_t structIndex = m_context->currentEnv->structs.Register(description, index); std::size_t structIndex = m_context->structs.Register(description, index);
m_context->currentEnv->identifiersInScope.push_back({ m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
structIndex, structIndex,
Identifier::Type::Struct IdentifierCategory::Struct
}); });
return structIndex; return structIndex;
@ -1918,12 +1946,12 @@ namespace Nz::ShaderAst
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
std::size_t typeIndex = m_context->currentEnv->types.Register(std::move(expressionType), index); std::size_t typeIndex = m_context->types.Register(std::move(expressionType), index);
m_context->currentEnv->identifiersInScope.push_back({ m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
typeIndex, typeIndex,
Identifier::Type::Type IdentifierCategory::Type
}); });
return typeIndex; return typeIndex;
@ -1934,12 +1962,12 @@ namespace Nz::ShaderAst
if (FindIdentifier(name)) if (FindIdentifier(name))
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
std::size_t typeIndex = m_context->currentEnv->types.Register(std::move(partialType), index); std::size_t typeIndex = m_context->types.Register(std::move(partialType), index);
m_context->currentEnv->identifiersInScope.push_back({ m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
typeIndex, typeIndex,
Identifier::Type::Type IdentifierCategory::Type
}); });
return typeIndex; return typeIndex;
@ -1950,21 +1978,29 @@ namespace Nz::ShaderAst
if (auto* identifier = FindIdentifier(name)) if (auto* identifier = FindIdentifier(name))
{ {
// Allow variable shadowing // Allow variable shadowing
if (identifier->type != Identifier::Type::Variable) if (identifier->category != IdentifierCategory::Variable)
throw AstError{ name + " is already used" }; throw AstError{ name + " is already used" };
} }
std::size_t varIndex = m_context->currentEnv->variableTypes.Register(std::move(type), index); std::size_t varIndex = m_context->variableTypes.Register(std::move(type), index);
m_context->currentEnv->identifiersInScope.push_back({ m_context->currentEnv->identifiersInScope.push_back({
std::move(name), std::move(name),
varIndex, varIndex,
Identifier::Type::Variable IdentifierCategory::Variable
}); });
return varIndex; return varIndex;
} }
auto SanitizeVisitor::ResolveAlias(const IdentifierData* identifier) const -> const IdentifierData*
{
while (identifier->category == IdentifierCategory::Alias)
identifier = &m_context->aliases.Retrieve(identifier->index);
return identifier;
}
void SanitizeVisitor::ResolveFunctions() void SanitizeVisitor::ResolveFunctions()
{ {
// Once every function is known, we can evaluate function content // Once every function is known, we can evaluate function content
@ -1997,7 +2033,7 @@ namespace Nz::ShaderAst
std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex; std::size_t funcIndex = *pendingFunc.cloneNode->funcIndex;
for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i)) for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i))
{ {
auto& targetFunc = m_context->currentEnv->functions.Retrieve(i); auto& targetFunc = m_context->functions.Retrieve(i);
targetFunc.calledByFunctions.UnboundedSet(funcIndex); targetFunc.calledByFunctions.UnboundedSet(funcIndex);
} }
@ -2006,13 +2042,13 @@ namespace Nz::ShaderAst
m_context->pendingFunctions.clear(); m_context->pendingFunctions.clear();
Bitset<> seen; Bitset<> seen;
for (const auto& [funcIndex, funcData] : m_context->currentEnv->functions.values) for (const auto& [funcIndex, funcData] : m_context->functions.values)
{ {
PropagateFunctionFlags(funcIndex, funcData.flags, seen); PropagateFunctionFlags(funcIndex, funcData.flags, seen);
seen.Clear(); seen.Clear();
} }
for (const auto& [funcIndex, funcData] : m_context->currentEnv->functions.values) for (const auto& [funcIndex, funcData] : m_context->functions.values)
{ {
if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment) if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment)
throw AstError{ "discard can only be used in the fragment stage" }; throw AstError{ "discard can only be used in the fragment stage" };
@ -2064,14 +2100,14 @@ namespace Nz::ShaderAst
std::size_t SanitizeVisitor::ResolveStruct(const IdentifierType& identifierType) std::size_t SanitizeVisitor::ResolveStruct(const IdentifierType& identifierType)
{ {
const Identifier* identifier = FindIdentifier(identifierType.name); const IdentifierData* identifierData = FindIdentifier(identifierType.name);
if (!identifier) if (!identifierData)
throw AstError{ "unknown identifier " + identifierType.name }; throw AstError{ "unknown identifier " + identifierType.name };
if (identifier->type != Identifier::Type::Struct) if (identifierData->category != IdentifierCategory::Struct)
throw AstError{ identifierType.name + " is not a struct" }; throw AstError{ identifierType.name + " is not a struct" };
return identifier->index; return identifierData->index;
} }
std::size_t SanitizeVisitor::ResolveStruct(const StructType& structType) std::size_t SanitizeVisitor::ResolveStruct(const StructType& structType)
@ -2091,7 +2127,7 @@ namespace Nz::ShaderAst
std::size_t typeIndex = std::get<Type>(exprType).typeIndex; std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
const auto& type = m_context->currentEnv->types.Retrieve(typeIndex); const auto& type = m_context->types.Retrieve(typeIndex);
if (std::holds_alternative<PartialType>(type)) if (std::holds_alternative<PartialType>(type))
throw AstError{ "full type expected" }; throw AstError{ "full type expected" };
@ -2186,6 +2222,21 @@ namespace Nz::ShaderAst
return node; return node;
} }
void SanitizeVisitor::Validate(DeclareAliasStatement& node)
{
if (node.name.empty())
throw std::runtime_error("invalid alias name");
ExpressionType exprType = GetExpressionType(*node.expression);
if (IsStructType(exprType))
{
std::size_t structIndex = ResolveStruct(exprType);
node.aliasIndex = RegisterAlias(node.name, { structIndex, IdentifierCategory::Struct }, node.aliasIndex);
}
else
throw AstError{ "for now, only structs can be aliased" };
}
void SanitizeVisitor::Validate(WhileStatement& node) void SanitizeVisitor::Validate(WhileStatement& node)
{ {
if (GetExpressionType(*node.condition) != ExpressionType{ PrimitiveType::Boolean }) if (GetExpressionType(*node.condition) != ExpressionType{ PrimitiveType::Boolean })
@ -2198,7 +2249,7 @@ namespace Nz::ShaderAst
if (IsTypeExpression(exprType)) if (IsTypeExpression(exprType))
{ {
std::size_t typeIndex = std::get<Type>(exprType).typeIndex; std::size_t typeIndex = std::get<Type>(exprType).typeIndex;
const auto& type = m_context->currentEnv->types.Retrieve(typeIndex); const auto& type = m_context->types.Retrieve(typeIndex);
if (!std::holds_alternative<PartialType>(type)) if (!std::holds_alternative<PartialType>(type))
throw std::runtime_error("only partial types can be specialized"); throw std::runtime_error("only partial types can be specialized");
@ -2291,7 +2342,7 @@ namespace Nz::ShaderAst
Int32 index = std::get<Int32>(constantExpr.value); Int32 index = std::get<Int32>(constantExpr.value);
std::size_t structIndex = ResolveStruct(exprType); std::size_t structIndex = ResolveStruct(exprType);
const StructDescription* s = m_context->currentEnv->structs.Retrieve(structIndex); const StructDescription* s = m_context->structs.Retrieve(structIndex);
exprType = ResolveType(s->members[index].type); exprType = ResolveType(s->members[index].type);
} }
@ -2365,7 +2416,7 @@ namespace Nz::ShaderAst
assert(std::holds_alternative<FunctionType>(targetFuncType)); assert(std::holds_alternative<FunctionType>(targetFuncType));
std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex; std::size_t targetFuncIndex = std::get<FunctionType>(targetFuncType).funcIndex;
auto& funcData = m_context->currentEnv->functions.Retrieve(targetFuncIndex); auto& funcData = m_context->functions.Retrieve(targetFuncIndex);
const DeclareFunctionStatement* referenceDeclaration = funcData.node; const DeclareFunctionStatement* referenceDeclaration = funcData.node;
@ -2482,9 +2533,9 @@ namespace Nz::ShaderAst
if (m_context->options.makeVariableNameUnique) if (m_context->options.makeVariableNameUnique)
{ {
// Since we are registered, FindIdentifier will find us // Since we are registered, FindIdentifier will find us
auto IgnoreOurself = [varIndex = *node.varIndex](const Identifier& identifier) auto IgnoreOurself = [varIndex = *node.varIndex](const IdentifierData& identifierData)
{ {
if (identifier.type == Identifier::Type::Variable && identifier.index == varIndex) if (identifierData.category == IdentifierCategory::Variable && identifierData.index == varIndex)
return false; return false;
return true; return true;
@ -2728,7 +2779,7 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(VariableExpression& node) void SanitizeVisitor::Validate(VariableExpression& node)
{ {
node.cachedExpressionType = m_context->currentEnv->variableTypes.Retrieve(node.variableId); node.cachedExpressionType = m_context->variableTypes.Retrieve(node.variableId);
} }
ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr) ExpressionType SanitizeVisitor::ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr)

View File

@ -1017,6 +1017,11 @@ namespace Nz
} }
} }
void GlslWriter::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
{
/* nothing to do */
}
void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/) void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{ {
/* nothing to do */ /* nothing to do */

View File

@ -768,6 +768,18 @@ namespace Nz
node.statement->Visit(*this); node.statement->Visit(*this);
} }
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
{
throw std::runtime_error("TODO"); //< missing registering
assert(node.aliasIndex);
Append("alias ", node.name, " = ");
assert(node.expression);
node.expression->Visit(*this);
AppendLine(";");
}
void LangWriter::Visit(ShaderAst::DeclareConstStatement& node) void LangWriter::Visit(ShaderAst::DeclareConstStatement& node)
{ {
assert(node.constIndex); assert(node.constIndex);
@ -780,7 +792,7 @@ namespace Nz
node.expression->Visit(*this); node.expression->Visit(*this);
} }
Append(";"); AppendLine(";");
} }
void LangWriter::Visit(ShaderAst::ConstantValueExpression& node) void LangWriter::Visit(ShaderAst::ConstantValueExpression& node)

View File

@ -347,6 +347,7 @@ namespace Nz::ShaderLang
auto& importedModule = m_context->module->importedModules.emplace_back(); auto& importedModule = m_context->module->importedModules.emplace_back();
importedModule.module = std::move(module); importedModule.module = std::move(module);
importedModule.identifier = identifier;
} }
else else
{ {

View File

@ -598,6 +598,11 @@ namespace Nz
}, node.value); }, node.value);
} }
void SpirvAstVisitor::Visit(ShaderAst::DeclareAliasStatement& /*node*/)
{
/* nothing to do */
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/) void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{ {
/* nothing to do */ /* nothing to do */

View File

@ -4,11 +4,11 @@ std::filesystem::path GetResourceDir()
{ {
static std::filesystem::path resourceDir = [] static std::filesystem::path resourceDir = []
{ {
std::filesystem::path resourceDir = "resources"; std::filesystem::path dir = "resources";
if (!std::filesystem::is_directory(resourceDir) && std::filesystem::is_directory(".." / resourceDir)) if (!std::filesystem::is_directory(dir) && std::filesystem::is_directory(".." / dir))
return ".." / resourceDir; return ".." / dir;
else else
return resourceDir; return dir;
}(); }();
return resourceDir; return resourceDir;