Modules are workings \o/

This commit is contained in:
Jérôme Leclercq 2022-03-08 20:26:02 +01:00
parent 83d26e209e
commit be9bdc4705
29 changed files with 742 additions and 256 deletions

View File

@ -9,16 +9,26 @@
NAZARA_REQUEST_DEDICATED_GPU() NAZARA_REQUEST_DEDICATED_GPU()
const char moduleSource[] = R"( const char barModuleSource[] = R"(
[nzsl_version("1.0")] [nzsl_version("1.0")]
[uuid("4BB09DEE-F70A-442E-859F-E8F2F3F8583D")]
module; module;
fn dummy() {} fn dummy() {}
[export]
[layout(std140)] [layout(std140)]
struct Bar struct Bar
{ {
} }
)";
const char dataModuleSource[] = R"(
[nzsl_version("1.0")]
[uuid("E49DC9AD-469C-462C-9719-A6F012372029")]
module;
import Test/Bar;
struct Foo {} struct Foo {}
@ -37,14 +47,15 @@ const char shaderSource[] = R"(
[nzsl_version("1.0")] [nzsl_version("1.0")]
module; module;
import Test/module_test; import Test/Data;
import Test/Bar;
option red: bool = false; option red: bool = false;
[set(0)] [set(0)]
external external
{ {
[binding(0)] viewerData: uniform[Data], [binding(0)] viewerData: uniform[Data]
} }
[set(1)] [set(1)]
@ -138,10 +149,12 @@ int main()
if (modulePath[0] != "Test") if (modulePath[0] != "Test")
return {}; return {};
if (modulePath[1] != "module_test") if (modulePath[1] == "Bar")
return Nz::ShaderLang::Parse(std::string_view(barModuleSource, sizeof(barModuleSource)));
else if (modulePath[1] == "Data")
return Nz::ShaderLang::Parse(std::string_view(dataModuleSource, sizeof(dataModuleSource)));
else
return {}; return {};
return Nz::ShaderLang::Parse(std::string_view(moduleSource, sizeof(moduleSource)));
}; };
shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt); shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt);
@ -152,10 +165,12 @@ int main()
} }
Nz::LangWriter langWriter; Nz::LangWriter langWriter;
std::cout << langWriter.Generate(*shaderModule) << std::endl; std::string output = langWriter.Generate(*shaderModule);
std::cout << output << std::endl;
assert(Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(output)));
Nz::ShaderWriter::States states; Nz::ShaderWriter::States states;
states.optimize = true; states.optimize = false;
auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states); auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states);
if (!fragVertShader) if (!fragVertShader)

View File

@ -50,8 +50,11 @@ namespace Nz::ShaderAst
virtual ExpressionPtr Clone(ConditionalExpression& node); virtual ExpressionPtr Clone(ConditionalExpression& node);
virtual ExpressionPtr Clone(ConstantExpression& node); virtual ExpressionPtr Clone(ConstantExpression& node);
virtual ExpressionPtr Clone(ConstantValueExpression& node); virtual ExpressionPtr Clone(ConstantValueExpression& node);
virtual ExpressionPtr Clone(FunctionExpression& node);
virtual ExpressionPtr Clone(IdentifierExpression& node); virtual ExpressionPtr Clone(IdentifierExpression& node);
virtual ExpressionPtr Clone(IntrinsicExpression& node); virtual ExpressionPtr Clone(IntrinsicExpression& node);
virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node);
virtual ExpressionPtr Clone(StructTypeExpression& node);
virtual ExpressionPtr Clone(SwizzleExpression& node); virtual ExpressionPtr Clone(SwizzleExpression& node);
virtual ExpressionPtr Clone(VariableExpression& node); virtual ExpressionPtr Clone(VariableExpression& node);
virtual ExpressionPtr Clone(UnaryExpression& node); virtual ExpressionPtr Clone(UnaryExpression& node);

View File

@ -42,8 +42,11 @@ namespace Nz::ShaderAst
inline bool Compare(const ConditionalExpression& lhs, const ConditionalExpression& rhs); inline bool Compare(const ConditionalExpression& lhs, const ConditionalExpression& rhs);
inline bool Compare(const ConstantExpression& lhs, const ConstantExpression& rhs); inline bool Compare(const ConstantExpression& lhs, const ConstantExpression& rhs);
inline bool Compare(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs); inline bool Compare(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs);
inline bool Compare(const FunctionExpression& lhs, const FunctionExpression& rhs);
inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs); inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs);
inline bool Compare(const IntrinsicExpression& lhs, const IntrinsicExpression& rhs); inline bool Compare(const IntrinsicExpression& lhs, const IntrinsicExpression& rhs);
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs);
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs);
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs); inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs);
inline bool Compare(const VariableExpression& lhs, const VariableExpression& rhs); inline bool Compare(const VariableExpression& lhs, const VariableExpression& rhs);
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs); inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs);

View File

@ -342,6 +342,14 @@ namespace Nz::ShaderAst
return true; return true;
} }
inline bool Compare(const FunctionExpression& lhs, const FunctionExpression& rhs)
{
if (!Compare(lhs.funcId, rhs.funcId))
return false;
return true;
}
inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs) inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs)
{ {
if (!Compare(lhs.identifier, rhs.identifier)) if (!Compare(lhs.identifier, rhs.identifier))
@ -361,6 +369,22 @@ namespace Nz::ShaderAst
return true; return true;
} }
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs)
{
if (!Compare(lhs.intrinsicId, rhs.intrinsicId))
return false;
return true;
}
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs)
{
if (!Compare(lhs.structTypeId, rhs.structTypeId))
return false;
return true;
}
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs) inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs)
{ {
if (!Compare(lhs.componentCount, rhs.componentCount)) if (!Compare(lhs.componentCount, rhs.componentCount))

View File

@ -38,8 +38,11 @@ NAZARA_SHADERAST_EXPRESSION(CastExpression)
NAZARA_SHADERAST_EXPRESSION(ConditionalExpression) NAZARA_SHADERAST_EXPRESSION(ConditionalExpression)
NAZARA_SHADERAST_EXPRESSION(ConstantExpression) NAZARA_SHADERAST_EXPRESSION(ConstantExpression)
NAZARA_SHADERAST_EXPRESSION(ConstantValueExpression) NAZARA_SHADERAST_EXPRESSION(ConstantValueExpression)
NAZARA_SHADERAST_EXPRESSION(FunctionExpression)
NAZARA_SHADERAST_EXPRESSION(IdentifierExpression) NAZARA_SHADERAST_EXPRESSION(IdentifierExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression) NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression)
NAZARA_SHADERAST_EXPRESSION(StructTypeExpression)
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression) NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
NAZARA_SHADERAST_EXPRESSION(VariableExpression) NAZARA_SHADERAST_EXPRESSION(VariableExpression)
NAZARA_SHADERAST_EXPRESSION(UnaryExpression) NAZARA_SHADERAST_EXPRESSION(UnaryExpression)

View File

@ -30,8 +30,11 @@ namespace Nz::ShaderAst
void Visit(ConditionalExpression& node) override; void Visit(ConditionalExpression& node) override;
void Visit(ConstantValueExpression& node) override; void Visit(ConstantValueExpression& node) override;
void Visit(ConstantExpression& node) override; void Visit(ConstantExpression& node) override;
void Visit(FunctionExpression& node) override;
void Visit(IdentifierExpression& node) override; void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override; void Visit(IntrinsicExpression& node) override;
void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override; void Visit(SwizzleExpression& node) override;
void Visit(VariableExpression& node) override; void Visit(VariableExpression& node) override;
void Visit(UnaryExpression& node) override; void Visit(UnaryExpression& node) override;

View File

@ -33,8 +33,11 @@ namespace Nz::ShaderAst
void Serialize(ConstantExpression& node); void Serialize(ConstantExpression& node);
void Serialize(ConditionalExpression& node); void Serialize(ConditionalExpression& node);
void Serialize(ConstantValueExpression& node); void Serialize(ConstantValueExpression& node);
void Serialize(FunctionExpression& node);
void Serialize(IdentifierExpression& node); void Serialize(IdentifierExpression& node);
void Serialize(IntrinsicExpression& node); void Serialize(IntrinsicExpression& node);
void Serialize(IntrinsicFunctionExpression& node);
void Serialize(StructTypeExpression& node);
void Serialize(SwizzleExpression& node); void Serialize(SwizzleExpression& node);
void Serialize(VariableExpression& node); void Serialize(VariableExpression& node);
void Serialize(UnaryExpression& node); void Serialize(UnaryExpression& node);

View File

@ -41,8 +41,11 @@ namespace Nz::ShaderAst
void Visit(ConditionalExpression& node) override; void Visit(ConditionalExpression& node) override;
void Visit(ConstantValueExpression& node) override; void Visit(ConstantValueExpression& node) override;
void Visit(ConstantExpression& node) override; void Visit(ConstantExpression& node) override;
void Visit(FunctionExpression& node) override;
void Visit(IdentifierExpression& node) override; void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override; void Visit(IntrinsicExpression& node) override;
void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override; void Visit(SwizzleExpression& node) override;
void Visit(VariableExpression& node) override; void Visit(VariableExpression& node) override;
void Visit(UnaryExpression& node) override; void Visit(UnaryExpression& node) override;

View File

@ -16,6 +16,9 @@ namespace Nz::ShaderAst
inline ModulePtr EliminateUnusedPass(const Module& shaderModule, const DependencyCheckerVisitor::Config& config) inline ModulePtr EliminateUnusedPass(const Module& shaderModule, const DependencyCheckerVisitor::Config& config)
{ {
DependencyCheckerVisitor dependencyVisitor; DependencyCheckerVisitor dependencyVisitor;
for (const auto& importedModule : shaderModule.importedModules)
dependencyVisitor.Process(*importedModule.module->rootNode, config);
dependencyVisitor.Process(*shaderModule.rootNode, config); dependencyVisitor.Process(*shaderModule.rootNode, config);
dependencyVisitor.Resolve(); dependencyVisitor.Resolve();

View File

@ -45,7 +45,8 @@ namespace Nz::ShaderAst
StatementPtr Clone(DeclareStructStatement& node) override; StatementPtr Clone(DeclareStructStatement& node) override;
StatementPtr Clone(DeclareVariableStatement& node) override; StatementPtr Clone(DeclareVariableStatement& node) override;
ExpressionPtr Clone(CallFunctionExpression& node) override; ExpressionPtr Clone(FunctionExpression& node) override;
ExpressionPtr Clone(StructTypeExpression& node) override;
ExpressionPtr Clone(VariableExpression& node) override; ExpressionPtr Clone(VariableExpression& node) override;
void HandleType(ExpressionValue<ExpressionType>& exprType); void HandleType(ExpressionValue<ExpressionType>& exprType);

View File

@ -156,6 +156,14 @@ namespace Nz::ShaderAst
ShaderAst::ConstantValue value; ShaderAst::ConstantValue value;
}; };
struct NAZARA_SHADER_API FunctionExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t funcId;
};
struct NAZARA_SHADER_API IdentifierExpression : Expression struct NAZARA_SHADER_API IdentifierExpression : Expression
{ {
NodeType GetType() const override; NodeType GetType() const override;
@ -173,6 +181,22 @@ namespace Nz::ShaderAst
IntrinsicType intrinsic; IntrinsicType intrinsic;
}; };
struct NAZARA_SHADER_API IntrinsicFunctionExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t intrinsicId;
};
struct NAZARA_SHADER_API StructTypeExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t structTypeId;
};
struct NAZARA_SHADER_API SwizzleExpression : Expression struct NAZARA_SHADER_API SwizzleExpression : Expression
{ {
NodeType GetType() const override; NodeType GetType() const override;

View File

@ -108,6 +108,8 @@ namespace Nz::ShaderAst
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const; template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
TypeParameter FindTypeParameter(const std::string_view& identifierName) const; TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
Expression& MandatoryExpr(const ExpressionPtr& node) const; Expression& MandatoryExpr(const ExpressionPtr& node) const;
Statement& MandatoryStatement(const StatementPtr& node) const; Statement& MandatoryStatement(const StatementPtr& node) const;

View File

@ -85,7 +85,7 @@ namespace Nz
void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node); void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node);
void HandleInOut(); void HandleInOut();
void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc); void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc, std::string structName);
void RegisterVariable(std::size_t varIndex, std::string varName); void RegisterVariable(std::size_t varIndex, std::string varName);
void ScopeVisit(ShaderAst::Statement& node); void ScopeVisit(ShaderAst::Statement& node);

View File

@ -43,10 +43,12 @@ namespace Nz
struct DepthWriteAttribute; struct DepthWriteAttribute;
struct EarlyFragmentTestsAttribute; struct EarlyFragmentTestsAttribute;
struct EntryAttribute; struct EntryAttribute;
struct LangVersionAttribute;
struct LayoutAttribute; struct LayoutAttribute;
struct LocationAttribute; struct LocationAttribute;
struct SetAttribute; struct SetAttribute;
struct UnrollAttribute; struct UnrollAttribute;
struct UuidAttribute;
void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ArrayType& type);
void Append(const ShaderAst::ExpressionType& type); void Append(const ShaderAst::ExpressionType& type);
@ -68,18 +70,21 @@ namespace Nz
template<typename... Args> void AppendAttributes(bool appendLine, Args&&... params); template<typename... Args> void AppendAttributes(bool appendLine, Args&&... params);
template<typename T> void AppendAttributesInternal(bool& first, const T& param); template<typename T> void AppendAttributesInternal(bool& first, const T& param);
template<typename T1, typename T2, typename... Rest> void AppendAttributesInternal(bool& first, const T1& firstParam, const T2& secondParam, Rest&&... params); template<typename T1, typename T2, typename... Rest> void AppendAttributesInternal(bool& first, const T1& firstParam, const T2& secondParam, Rest&&... params);
void AppendAttribute(BindingAttribute binding); void AppendAttribute(BindingAttribute attribute);
void AppendAttribute(BuiltinAttribute builtin); void AppendAttribute(BuiltinAttribute attribute);
void AppendAttribute(DepthWriteAttribute depthWrite); void AppendAttribute(DepthWriteAttribute attribute);
void AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests); void AppendAttribute(EarlyFragmentTestsAttribute attribute);
void AppendAttribute(EntryAttribute entry); void AppendAttribute(EntryAttribute attribute);
void AppendAttribute(LayoutAttribute layout); void AppendAttribute(LangVersionAttribute attribute);
void AppendAttribute(LocationAttribute location); void AppendAttribute(LayoutAttribute attribute);
void AppendAttribute(SetAttribute set); void AppendAttribute(LocationAttribute attribute);
void AppendAttribute(UnrollAttribute unroll); void AppendAttribute(SetAttribute seattributet);
void AppendAttribute(UnrollAttribute attribute);
void AppendAttribute(UuidAttribute attribute);
void AppendComment(const std::string& section); void AppendComment(const std::string& section);
void AppendCommentSection(const std::string& section); void AppendCommentSection(const std::string& section);
void AppendHeader(); void AppendHeader();
template<typename T> void AppendIdentifier(const T& map, std::size_t id);
void AppendLine(const std::string& txt = {}); void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params); template<typename... Args> void AppendLine(Args&&... params);
void AppendStatementList(std::vector<ShaderAst::StatementPtr>& statements); void AppendStatementList(std::vector<ShaderAst::StatementPtr>& statements);
@ -88,7 +93,7 @@ namespace Nz
void LeaveScope(bool skipLine = true); void LeaveScope(bool skipLine = true);
void RegisterConstant(std::size_t constantIndex, std::string constantName); void RegisterConstant(std::size_t constantIndex, std::string constantName);
void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc); void RegisterStruct(std::size_t structIndex, std::string structName);
void RegisterVariable(std::size_t varIndex, std::string varName); void RegisterVariable(std::size_t varIndex, std::string varName);
void ScopeVisit(ShaderAst::Statement& node); void ScopeVisit(ShaderAst::Statement& node);
@ -104,6 +109,7 @@ namespace Nz
void Visit(ShaderAst::ConstantValueExpression& node) override; void Visit(ShaderAst::ConstantValueExpression& node) override;
void Visit(ShaderAst::ConstantExpression& node) override; void Visit(ShaderAst::ConstantExpression& node) override;
void Visit(ShaderAst::IntrinsicExpression& node) override; void Visit(ShaderAst::IntrinsicExpression& node) override;
void Visit(ShaderAst::StructTypeExpression& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override; void Visit(ShaderAst::VariableExpression& node) override;
void Visit(ShaderAst::UnaryExpression& node) override; void Visit(ShaderAst::UnaryExpression& node) override;

View File

@ -127,6 +127,11 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::ForEachStatement> operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const; inline std::unique_ptr<ShaderAst::ForEachStatement> operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const;
}; };
struct Function
{
inline std::unique_ptr<ShaderAst::FunctionExpression> operator()(std::size_t funcId) const;
};
struct Identifier struct Identifier
{ {
inline std::unique_ptr<ShaderAst::IdentifierExpression> operator()(std::string name) const; inline std::unique_ptr<ShaderAst::IdentifierExpression> operator()(std::string name) const;
@ -142,6 +147,11 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::IntrinsicExpression> operator()(ShaderAst::IntrinsicType intrinsicType, std::vector<ShaderAst::ExpressionPtr> parameters) const; inline std::unique_ptr<ShaderAst::IntrinsicExpression> operator()(ShaderAst::IntrinsicType intrinsicType, std::vector<ShaderAst::ExpressionPtr> parameters) const;
}; };
struct IntrinsicFunction
{
inline std::unique_ptr<ShaderAst::IntrinsicFunctionExpression> operator()(std::size_t intrinsicFunctionId, ShaderAst::IntrinsicType intrinsicType) const;
};
struct Multi struct Multi
{ {
inline std::unique_ptr<ShaderAst::MultiStatement> operator()(std::vector<ShaderAst::StatementPtr> statements = {}) const; inline std::unique_ptr<ShaderAst::MultiStatement> operator()(std::vector<ShaderAst::StatementPtr> statements = {}) const;
@ -163,6 +173,11 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::ScopedStatement> operator()(ShaderAst::StatementPtr statement) const; inline std::unique_ptr<ShaderAst::ScopedStatement> operator()(ShaderAst::StatementPtr statement) const;
}; };
struct StructType
{
inline std::unique_ptr<ShaderAst::StructTypeExpression> operator()(std::size_t structTypeId) const;
};
struct Swizzle struct Swizzle
{ {
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const; inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const;
@ -206,13 +221,16 @@ namespace Nz::ShaderBuilder
constexpr Impl::NoParam<ShaderAst::DiscardStatement> Discard; constexpr Impl::NoParam<ShaderAst::DiscardStatement> Discard;
constexpr Impl::For For; constexpr Impl::For For;
constexpr Impl::ForEach ForEach; constexpr Impl::ForEach ForEach;
constexpr Impl::Function Function;
constexpr Impl::Identifier Identifier; constexpr Impl::Identifier Identifier;
constexpr Impl::IntrinsicFunction IntrinsicFunction;
constexpr Impl::Import Import; constexpr Impl::Import Import;
constexpr Impl::Intrinsic Intrinsic; constexpr Impl::Intrinsic Intrinsic;
constexpr Impl::Multi MultiStatement; constexpr Impl::Multi MultiStatement;
constexpr Impl::NoParam<ShaderAst::NoOpStatement> NoOp; constexpr Impl::NoParam<ShaderAst::NoOpStatement> NoOp;
constexpr Impl::Return Return; constexpr Impl::Return Return;
constexpr Impl::Scoped Scoped; constexpr Impl::Scoped Scoped;
constexpr Impl::StructType StructType;
constexpr Impl::Swizzle Swizzle; constexpr Impl::Swizzle Swizzle;
constexpr Impl::Unary Unary; constexpr Impl::Unary Unary;
constexpr Impl::Variable Variable; constexpr Impl::Variable Variable;

View File

@ -315,7 +315,7 @@ namespace Nz::ShaderBuilder
return expressionStatementNode; return expressionStatementNode;
} }
inline std::unique_ptr<ShaderAst::ForStatement> Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const inline std::unique_ptr<ShaderAst::ForStatement> Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const
{ {
auto forNode = std::make_unique<ShaderAst::ForStatement>(); auto forNode = std::make_unique<ShaderAst::ForStatement>();
forNode->fromExpr = std::move(fromExpression); forNode->fromExpr = std::move(fromExpression);
@ -326,7 +326,7 @@ namespace Nz::ShaderBuilder
return forNode; return forNode;
} }
inline std::unique_ptr<ShaderAst::ForStatement> Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const inline std::unique_ptr<ShaderAst::ForStatement> Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const
{ {
auto forNode = std::make_unique<ShaderAst::ForStatement>(); auto forNode = std::make_unique<ShaderAst::ForStatement>();
forNode->fromExpr = std::move(fromExpression); forNode->fromExpr = std::move(fromExpression);
@ -348,6 +348,15 @@ namespace Nz::ShaderBuilder
return forEachNode; return forEachNode;
} }
inline std::unique_ptr<ShaderAst::FunctionExpression> Impl::Function::operator()(std::size_t funcId) const
{
auto intrinsicTypeExpr = std::make_unique<ShaderAst::FunctionExpression>();
intrinsicTypeExpr->cachedExpressionType = ShaderAst::FunctionType{ funcId };
intrinsicTypeExpr->funcId = funcId;
return intrinsicTypeExpr;
}
inline std::unique_ptr<ShaderAst::IdentifierExpression> Impl::Identifier::operator()(std::string name) const inline std::unique_ptr<ShaderAst::IdentifierExpression> Impl::Identifier::operator()(std::string name) const
{ {
auto identifierNode = std::make_unique<ShaderAst::IdentifierExpression>(); auto identifierNode = std::make_unique<ShaderAst::IdentifierExpression>();
@ -373,6 +382,15 @@ namespace Nz::ShaderBuilder
return intrinsicExpression; return intrinsicExpression;
} }
inline std::unique_ptr<ShaderAst::IntrinsicFunctionExpression> Impl::IntrinsicFunction::operator()(std::size_t intrinsicFunctionId, ShaderAst::IntrinsicType intrinsicType) const
{
auto intrinsicTypeExpr = std::make_unique<ShaderAst::IntrinsicFunctionExpression>();
intrinsicTypeExpr->cachedExpressionType = ShaderAst::IntrinsicFunctionType{ intrinsicType };
intrinsicTypeExpr->intrinsicId = intrinsicFunctionId;
return intrinsicTypeExpr;
}
inline std::unique_ptr<ShaderAst::MultiStatement> Impl::Multi::operator()(std::vector<ShaderAst::StatementPtr> statements) const inline std::unique_ptr<ShaderAst::MultiStatement> Impl::Multi::operator()(std::vector<ShaderAst::StatementPtr> statements) const
{ {
auto multiStatement = std::make_unique<ShaderAst::MultiStatement>(); auto multiStatement = std::make_unique<ShaderAst::MultiStatement>();
@ -403,6 +421,15 @@ namespace Nz::ShaderBuilder
return scopedNode; return scopedNode;
} }
inline std::unique_ptr<ShaderAst::StructTypeExpression> Impl::StructType::operator()(std::size_t structTypeId) const
{
auto structTypeExpr = std::make_unique<ShaderAst::StructTypeExpression>();
structTypeExpr->cachedExpressionType = ShaderAst::StructType{ structTypeId };
structTypeExpr->structTypeId = structTypeId;
return structTypeExpr;
}
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const
{ {
assert(componentCount > 0); assert(componentCount > 0);

View File

@ -86,6 +86,7 @@ namespace Nz::ShaderLang
void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type, ShaderAst::ExpressionPtr& initialValue); void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type, ShaderAst::ExpressionPtr& initialValue);
// Statements // Statements
ShaderAst::StatementPtr ParseAliasDeclaration();
ShaderAst::StatementPtr ParseBranchStatement(); ShaderAst::StatementPtr ParseBranchStatement();
ShaderAst::StatementPtr ParseConstStatement(); ShaderAst::StatementPtr ParseConstStatement();
ShaderAst::StatementPtr ParseDiscardStatement(); ShaderAst::StatementPtr ParseDiscardStatement();
@ -130,6 +131,7 @@ namespace Nz::ShaderLang
std::size_t tokenIndex = 0; std::size_t tokenIndex = 0;
ShaderAst::ModulePtr module; ShaderAst::ModulePtr module;
const Token* tokens; const Token* tokens;
bool parsingImportedModule = false;
}; };
Context* m_context; Context* m_context;

View File

@ -12,6 +12,7 @@
#define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X) #define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X)
#endif #endif
NAZARA_SHADERLANG_TOKEN(Alias)
NAZARA_SHADERLANG_TOKEN(Arrow) NAZARA_SHADERLANG_TOKEN(Arrow)
NAZARA_SHADERLANG_TOKEN(Assign) NAZARA_SHADERLANG_TOKEN(Assign)
NAZARA_SHADERLANG_TOKEN(BoolFalse) NAZARA_SHADERLANG_TOKEN(BoolFalse)

View File

@ -405,6 +405,16 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
ExpressionPtr AstCloner::Clone(FunctionExpression& node)
{
auto clone = std::make_unique<FunctionExpression>();
clone->funcId = node.funcId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(IdentifierExpression& node) ExpressionPtr AstCloner::Clone(IdentifierExpression& node)
{ {
auto clone = std::make_unique<IdentifierExpression>(); auto clone = std::make_unique<IdentifierExpression>();
@ -429,6 +439,26 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
ExpressionPtr AstCloner::Clone(IntrinsicFunctionExpression& node)
{
auto clone = std::make_unique<IntrinsicFunctionExpression>();
clone->intrinsicId = node.intrinsicId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(StructTypeExpression& node)
{
auto clone = std::make_unique<StructTypeExpression>();
clone->structTypeId = node.structTypeId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(SwizzleExpression& node) ExpressionPtr AstCloner::Clone(SwizzleExpression& node)
{ {
auto clone = std::make_unique<SwizzleExpression>(); auto clone = std::make_unique<SwizzleExpression>();

View File

@ -72,6 +72,11 @@ namespace Nz::ShaderAst
/* Nothing to do */ /* Nothing to do */
} }
void AstRecursiveVisitor::Visit(FunctionExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/) void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/)
{ {
/* Nothing to do */ /* Nothing to do */
@ -83,6 +88,16 @@ namespace Nz::ShaderAst
param->Visit(*this); param->Visit(*this);
} }
void AstRecursiveVisitor::Visit(IntrinsicFunctionExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(StructTypeExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(SwizzleExpression& node) void AstRecursiveVisitor::Visit(SwizzleExpression& node)
{ {
if (node.expression) if (node.expression)

View File

@ -153,6 +153,21 @@ namespace Nz::ShaderAst
Node(param); Node(param);
} }
void AstSerializerBase::Serialize(IntrinsicFunctionExpression& node)
{
SizeT(node.intrinsicId);
}
void AstSerializerBase::Serialize(StructTypeExpression& node)
{
SizeT(node.structTypeId);
}
void AstSerializerBase::Serialize(FunctionExpression& node)
{
SizeT(node.funcId);
}
void AstSerializerBase::Serialize(SwizzleExpression& node) void AstSerializerBase::Serialize(SwizzleExpression& node)
{ {
SizeT(node.componentCount); SizeT(node.componentCount);

View File

@ -72,6 +72,11 @@ namespace Nz::ShaderAst
m_expressionCategory = ExpressionCategory::LValue; m_expressionCategory = ExpressionCategory::LValue;
} }
void ShaderAstValueCategory::Visit(FunctionExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/) void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/)
{ {
m_expressionCategory = ExpressionCategory::LValue; m_expressionCategory = ExpressionCategory::LValue;
@ -82,6 +87,16 @@ namespace Nz::ShaderAst
m_expressionCategory = ExpressionCategory::RValue; m_expressionCategory = ExpressionCategory::RValue;
} }
void ShaderAstValueCategory::Visit(IntrinsicFunctionExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(StructTypeExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(SwizzleExpression& node) void ShaderAstValueCategory::Visit(SwizzleExpression& node)
{ {
if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1) if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1)

View File

@ -137,20 +137,22 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
ExpressionPtr IndexRemapperVisitor::Clone(CallFunctionExpression& node) ExpressionPtr IndexRemapperVisitor::Clone(FunctionExpression& node)
{ {
CallFunctionExpressionPtr clone = static_unique_pointer_cast<CallFunctionExpression>(AstCloner::Clone(node)); FunctionExpressionPtr clone = static_unique_pointer_cast<FunctionExpression>(AstCloner::Clone(node));
const auto& targetFuncType = GetExpressionType(*node.targetFunction); assert(clone->funcId);
if (std::holds_alternative<FunctionType>(targetFuncType)) clone->funcId = Retrieve(m_context->newFuncIndices, clone->funcId);
{
const auto& funcType = std::get<FunctionType>(targetFuncType);
FunctionType newFunc; return clone;
newFunc.funcIndex = Retrieve(m_context->newFuncIndices, funcType.funcIndex); }
clone->cachedExpressionType = ExpressionType{ newFunc }; //< FIXME We should add FunctionExpression like VariableExpression to handle this
}
ExpressionPtr IndexRemapperVisitor::Clone(StructTypeExpression& node)
{
StructTypeExpressionPtr clone = static_unique_pointer_cast<StructTypeExpression>(AstCloner::Clone(node));
assert(clone->structTypeId);
clone->structTypeId = Retrieve(m_context->newStructIndices, clone->structTypeId);
return clone; return clone;
} }

View File

@ -7,6 +7,7 @@
#include <Nazara/Core/CallOnExit.hpp> #include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/StackArray.hpp> #include <Nazara/Core/StackArray.hpp>
#include <Nazara/Core/StackVector.hpp> #include <Nazara/Core/StackVector.hpp>
#include <Nazara/Core/Hash/SHA256.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp> #include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp> #include <Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp>
#include <Nazara/Shader/Ast/AstExportVisitor.hpp> #include <Nazara/Shader/Ast/AstExportVisitor.hpp>
@ -114,6 +115,7 @@ namespace Nz::ShaderAst
struct SanitizeVisitor::Environment struct SanitizeVisitor::Environment
{ {
Uuid moduleId;
std::shared_ptr<Environment> parentEnv; std::shared_ptr<Environment> parentEnv;
std::vector<Identifier> identifiersInScope; std::vector<Identifier> identifiersInScope;
std::vector<Scope> scopes; std::vector<Scope> scopes;
@ -121,21 +123,31 @@ namespace Nz::ShaderAst
struct SanitizeVisitor::Context struct SanitizeVisitor::Context
{ {
struct ModuleData
{
std::unordered_map<Uuid, DependencyCheckerVisitor::UsageSet> exportedSetByModule;
std::shared_ptr<Environment> environment;
std::unique_ptr<DependencyCheckerVisitor> dependenciesVisitor;
};
struct PendingFunction struct PendingFunction
{ {
DeclareFunctionStatement* cloneNode; DeclareFunctionStatement* cloneNode;
const DeclareFunctionStatement* node; const DeclareFunctionStatement* node;
}; };
static constexpr std::size_t ModuleIdSentinel = std::numeric_limits<std::size_t>::max();
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {}; std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::vector<std::shared_ptr<Environment>> moduleEnvironments; std::vector<ModuleData> modules;
std::vector<PendingFunction> pendingFunctions; std::vector<PendingFunction> pendingFunctions;
std::vector<StatementPtr>* currentStatementList = nullptr; std::vector<StatementPtr>* currentStatementList = nullptr;
std::unordered_map<Uuid, std::size_t> moduleByUuid; std::unordered_map<Uuid, std::size_t> moduleByUuid;
std::unordered_set<std::string> declaredExternalVar; std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<UInt64> usedBindingIndexes; std::unordered_set<UInt64> usedBindingIndexes;
std::shared_ptr<Environment> globalEnv; std::shared_ptr<Environment> builtinEnv;
std::shared_ptr<Environment> currentEnv; std::shared_ptr<Environment> currentEnv;
std::shared_ptr<Environment> globalEnv;
IdentifierList<ConstantValue> constantValues; IdentifierList<ConstantValue> constantValues;
IdentifierList<FunctionData> functions; IdentifierList<FunctionData> functions;
IdentifierList<IdentifierData> aliases; IdentifierList<IdentifierData> aliases;
@ -144,22 +156,57 @@ namespace Nz::ShaderAst
IdentifierList<StructDescription*> structs; IdentifierList<StructDescription*> structs;
IdentifierList<std::variant<ExpressionType, PartialType>> types; IdentifierList<std::variant<ExpressionType, PartialType>> types;
IdentifierList<ExpressionType> variableTypes; IdentifierList<ExpressionType> variableTypes;
ModulePtr currentModule;
Options options; Options options;
CurrentFunctionData* currentFunction = nullptr; CurrentFunctionData* currentFunction = nullptr;
}; };
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error) ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
{ {
Context currentContext;
currentContext.options = options;
ModulePtr clone = std::make_shared<Module>(module.metadata); ModulePtr clone = std::make_shared<Module>(module.metadata);
clone->importedModules = module.importedModules; clone->importedModules = module.importedModules;
Context currentContext;
currentContext.options = options;
currentContext.currentModule = clone;
m_context = &currentContext; m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; }); CallOnExit resetContext([&] { m_context = nullptr; });
// Register builtin env
m_context->builtinEnv = std::make_shared<Environment>();
m_context->currentEnv = m_context->builtinEnv;
RegisterBuiltin();
m_context->globalEnv = std::make_shared<Environment>(); m_context->globalEnv = std::make_shared<Environment>();
m_context->globalEnv->moduleId = clone->metadata->moduleId;
m_context->globalEnv->parentEnv = m_context->builtinEnv;
for (std::size_t moduleId = 0; moduleId < clone->importedModules.size(); ++moduleId)
{
auto moduleEnv = std::make_shared<Environment>();
moduleEnv->moduleId = clone->importedModules[moduleId].module->metadata->moduleId;
moduleEnv->parentEnv = m_context->builtinEnv;
m_context->currentEnv = moduleEnv;
// Previous modules are visibles
for (std::size_t previousModuleId = 0; previousModuleId < moduleId; ++previousModuleId)
RegisterModule(clone->importedModules[previousModuleId].identifier, previousModuleId);
auto& importedModule = clone->importedModules[moduleId];
importedModule.module->rootNode = SanitizeInternal(*importedModule.module->rootNode, error);
if (!importedModule.module->rootNode)
return {};
m_context->moduleByUuid[importedModule.module->metadata->moduleId] = moduleId;
auto& moduleData = m_context->modules.emplace_back();
moduleData.environment = std::move(moduleEnv);
m_context->currentEnv = m_context->globalEnv;
RegisterModule(importedModule.identifier, moduleId);
}
m_context->currentEnv = m_context->globalEnv; m_context->currentEnv = m_context->globalEnv;
clone->rootNode = SanitizeInternal(*module.rootNode, error); clone->rootNode = SanitizeInternal(*module.rootNode, error);
@ -211,7 +258,25 @@ namespace Nz::ShaderAst
if (node.identifiers.empty()) if (node.identifiers.empty())
throw AstError{ "AccessIdentifierExpression must have at least one identifier" }; throw AstError{ "AccessIdentifierExpression must have at least one identifier" };
ExpressionPtr indexedExpr = CloneExpression(MandatoryExpr(node.expr)); MandatoryExpr(node.expr);
// Handle module access (TODO: Add namespace expression?)
if (node.expr->GetType() == NodeType::IdentifierExpression && node.identifiers.size() == 1)
{
auto& identifierExpr = static_cast<IdentifierExpression&>(*node.expr);
const IdentifierData* identifierData = FindIdentifier(identifierExpr.identifier);
if (identifierData && identifierData->category == IdentifierCategory::Module)
{
std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index);
const auto& env = *m_context->modules[moduleIndex].environment;
identifierData = FindIdentifier(env, node.identifiers.front());
if (identifierData)
return HandleIdentifier(identifierData);
}
}
ExpressionPtr indexedExpr = CloneExpression(node.expr);
for (const std::string& identifier : node.identifiers) for (const std::string& identifier : node.identifiers)
{ {
if (identifier.empty()) if (identifier.empty())
@ -393,7 +458,10 @@ namespace Nz::ShaderAst
if (!m_context->currentFunction) if (!m_context->currentFunction)
throw AstError{ "function calls must happen inside a function" }; throw AstError{ "function calls must happen inside a function" };
std::size_t targetFuncIndex = std::get<FunctionType>(targetExprType).funcIndex; if (targetExpr->GetType() != NodeType::FunctionExpression)
throw AstError{ "expected function expression" };
std::size_t targetFuncIndex = static_cast<FunctionExpression&>(*targetExpr).funcId;
auto clone = std::make_unique<CallFunctionExpression>(); auto clone = std::make_unique<CallFunctionExpression>();
clone->targetFunction = std::move(targetExpr); clone->targetFunction = std::move(targetExpr);
@ -410,13 +478,18 @@ namespace Nz::ShaderAst
} }
else if (IsIntrinsicFunctionType(targetExprType)) else if (IsIntrinsicFunctionType(targetExprType))
{ {
if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression)
throw AstError{ "expected intrinsic function expression" };
std::size_t targetIntrinsicId = static_cast<IntrinsicFunctionExpression&>(*targetExpr).intrinsicId;
std::vector<ExpressionPtr> parameters; std::vector<ExpressionPtr> parameters;
parameters.reserve(node.parameters.size()); parameters.reserve(node.parameters.size());
for (const auto& param : node.parameters) for (const auto& param : node.parameters)
parameters.push_back(CloneExpression(param)); parameters.push_back(CloneExpression(param));
auto intrinsic = ShaderBuilder::Intrinsic(std::get<IntrinsicFunctionType>(targetExprType).intrinsic, std::move(parameters)); auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics.Retrieve(targetIntrinsicId), std::move(parameters));
Validate(*intrinsic); Validate(*intrinsic);
return intrinsic; return intrinsic;
@ -584,64 +657,7 @@ namespace Nz::ShaderAst
if (!identifierData) if (!identifierData)
throw AstError{ "unknown identifier " + node.identifier }; throw AstError{ "unknown identifier " + node.identifier };
switch (identifierData->category) return HandleIdentifier(identifierData);
{
case IdentifierCategory::Constant:
{
// Replace IdentifierExpression by Constant(Value)Expression
ConstantExpression constantExpr;
constantExpr.constantId = identifierData->index;
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
}
case IdentifierCategory::Function:
{
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = FunctionType{ identifierData->index };
return clone;
}
case IdentifierCategory::Intrinsic:
{
IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index);
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType };
return clone;
}
case IdentifierCategory::Struct:
{
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = StructType{ identifierData->index };
return clone;
}
case IdentifierCategory::Type:
{
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = Type{ identifierData->index };
return clone;
}
case IdentifierCategory::Variable:
{
// Replace IdentifierExpression by VariableExpression
auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index);
varExpr->variableId = identifierData->index;
return varExpr;
}
default:
throw AstError{ "unexpected identifier" };
}
} }
ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node) ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node)
@ -872,6 +888,8 @@ namespace Nz::ShaderAst
if (node.returnType.HasValue()) if (node.returnType.HasValue())
clone->returnType = ResolveType(node.returnType); clone->returnType = ResolveType(node.returnType);
else
clone->returnType = ExpressionType{ NoType{} };
if (node.depthWrite.HasValue()) if (node.depthWrite.HasValue())
clone->depthWrite = ComputeExprValue(node.depthWrite); clone->depthWrite = ComputeExprValue(node.depthWrite);
@ -1360,54 +1378,99 @@ namespace Nz::ShaderAst
if (!targetModule) if (!targetModule)
throw AstError{ "module " + ModulePathAsString() + " not found" }; throw AstError{ "module " + ModulePathAsString() + " not found" };
targetModule->rootNode->sectionName = "Module " + targetModule->metadata->moduleId.ToString(); std::size_t moduleIndex;
m_context->currentEnv = m_context->moduleEnvironments.emplace_back(std::make_shared<Environment>()); const Uuid& moduleUuid = targetModule->metadata->moduleId;
CallOnExit restoreEnvOnExit([&] { m_context->currentEnv = m_context->globalEnv; }); auto it = m_context->moduleByUuid.find(moduleUuid);
if (it == m_context->moduleByUuid.end())
{
m_context->moduleByUuid[moduleUuid] = Context::ModuleIdSentinel;
ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata); // Generate module identifier (based on UUID)
const auto& moduleUuidBytes = moduleUuid.ToArray();
std::string error; SHA256Hash hasher;
sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error); hasher.Begin();
if (!sanitizedModule) hasher.Append(moduleUuidBytes.data(), moduleUuidBytes.size());
throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error }; hasher.End();
std::string identifier = "__" + hasher.End().ToHex().substr(0, 8);
// Load new module
auto moduleEnvironment = std::make_shared<Environment>();
moduleEnvironment->parentEnv = m_context->builtinEnv;
auto previousEnv = m_context->currentEnv;
m_context->currentEnv = moduleEnvironment;
ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata);
std::string error;
sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error);
if (!sanitizedModule)
throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error };
moduleIndex = m_context->modules.size();
assert(m_context->modules.size() == moduleIndex);
auto& moduleData = m_context->modules.emplace_back();
moduleData.dependenciesVisitor = std::make_unique<DependencyCheckerVisitor>();
moduleData.dependenciesVisitor->Process(*sanitizedModule->rootNode);
moduleData.environment = std::move(moduleEnvironment);
assert(m_context->currentModule->importedModules.size() == moduleIndex);
auto& importedModule = m_context->currentModule->importedModules.emplace_back();
importedModule.identifier = identifier;
importedModule.module = std::move(sanitizedModule);
m_context->currentEnv = std::move(previousEnv);
RegisterModule(identifier, moduleIndex);
m_context->moduleByUuid[moduleUuid] = moduleIndex;
}
else
{
// Module has already been imported
moduleIndex = it->second;
if (moduleIndex == Context::ModuleIdSentinel)
throw AstError{ "circular import detected" };
}
auto& moduleData = m_context->modules[moduleIndex];
auto& exportedSet = moduleData.exportedSetByModule[m_context->currentEnv->moduleId];
// Extract exported nodes and their dependencies // Extract exported nodes and their dependencies
DependencyCheckerVisitor::Config depConfig; std::vector<DeclareAliasStatementPtr> aliasStatements;
depConfig.usedShaderStages.Clear();
DependencyCheckerVisitor moduleDependencies;
moduleDependencies.Process(*sanitizedModule->rootNode, depConfig);
DependencyCheckerVisitor::UsageSet exportedSet;
MultiStatementPtr aliasBlock = std::make_unique<MultiStatement>();
AstExportVisitor::Callbacks callbacks; AstExportVisitor::Callbacks callbacks;
callbacks.onExportedStruct = [&](DeclareStructStatement& node) callbacks.onExportedStruct = [&](DeclareStructStatement& node)
{ {
assert(node.structIndex); assert(node.structIndex);
moduleDependencies.MarkStructAsUsed(*node.structIndex); moduleData.dependenciesVisitor->MarkStructAsUsed(*node.structIndex);
exportedSet.usedStructs.UnboundedSet(*node.structIndex);
auto alias = Clone(node);
// TODO: DeclareAlias
aliasBlock->statements.emplace_back(std::move(alias));
if (!exportedSet.usedStructs.UnboundedTest(*node.structIndex))
{
exportedSet.usedStructs.UnboundedSet(*node.structIndex);
aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.description.name, ShaderBuilder::StructType(*node.structIndex)));
}
}; };
AstExportVisitor exportVisitor; AstExportVisitor exportVisitor;
exportVisitor.Visit(*sanitizedModule->rootNode, callbacks); exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
moduleDependencies.Resolve(); if (aliasStatements.empty())
return ShaderBuilder::NoOp();
//m_context-> // Register module and aliases
MultiStatementPtr aliasBlock = std::make_unique<MultiStatement>();
// Register exported variables (FIXME: This shouldn't be necessary and could be handled by the IndexRemapperVisitor) for (auto& aliasPtr : aliasStatements)
//m_context->importUsage = remappedExportedSet; {
//CallOnExit restoreImportOnExit([&] { m_context->importUsage.reset(); }); Validate(*aliasPtr);
aliasBlock->statements.push_back(std::move(aliasPtr));
}
return aliasBlock; return aliasBlock;
} }
@ -1559,6 +1622,74 @@ namespace Nz::ShaderAst
throw std::runtime_error("internal error"); throw std::runtime_error("internal error");
} }
ExpressionPtr SanitizeVisitor::HandleIdentifier(const IdentifierData* identifierData)
{
switch (identifierData->category)
{
case IdentifierCategory::Constant:
{
// Replace IdentifierExpression by Constant(Value)Expression
ConstantExpression constantExpr;
constantExpr.constantId = identifierData->index;
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
}
case IdentifierCategory::Function:
{
// Replace IdentifierExpression by FunctionExpression
auto funcExpr = std::make_unique<FunctionExpression>();
funcExpr->cachedExpressionType = FunctionType{ identifierData->index }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type
funcExpr->funcId = identifierData->index;
return funcExpr;
}
case IdentifierCategory::Intrinsic:
{
IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index);
// Replace IdentifierExpression by IntrinsicFunctionExpression
auto intrinsicExpr = std::make_unique<IntrinsicFunctionExpression>();
intrinsicExpr->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type
intrinsicExpr->intrinsicId = identifierData->index;
return intrinsicExpr;
}
case IdentifierCategory::Struct:
{
// Replace IdentifierExpression by StructTypeExpression
auto structExpr = std::make_unique<StructTypeExpression>();
structExpr->cachedExpressionType = StructType{ identifierData->index };
structExpr->structTypeId = identifierData->index;
return structExpr;
}
case IdentifierCategory::Type:
{
auto clone = ShaderBuilder::Identifier("dummy");
clone->cachedExpressionType = Type{ identifierData->index };
return clone;
}
case IdentifierCategory::Variable:
{
// Replace IdentifierExpression by VariableExpression
auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index);
varExpr->variableId = identifierData->index;
return varExpr;
}
default:
throw AstError{ "unexpected identifier" };
}
}
Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node) const Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node) const
{ {
if (!node) if (!node)
@ -1909,20 +2040,20 @@ namespace Nz::ShaderAst
return intrinsicIndex; return intrinsicIndex;
} }
std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex) std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t index)
{ {
if (FindIdentifier(moduleIdentifier)) if (FindIdentifier(moduleIdentifier))
throw AstError{ moduleIdentifier + " is already used" }; throw AstError{ moduleIdentifier + " is already used" };
std::size_t intrinsicIndex = m_context->moduleIndices.Register(moduleIndex); std::size_t moduleIndex = m_context->moduleIndices.Register(index);
m_context->currentEnv->identifiersInScope.push_back({ m_context->currentEnv->identifiersInScope.push_back({
std::move(moduleIdentifier), std::move(moduleIdentifier),
intrinsicIndex, moduleIndex,
IdentifierCategory::Module IdentifierCategory::Module
}); });
return intrinsicIndex; return moduleIndex;
} }
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index) std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index)
@ -2169,11 +2300,7 @@ namespace Nz::ShaderAst
MultiStatementPtr SanitizeVisitor::SanitizeInternal(MultiStatement& rootNode, std::string* error) MultiStatementPtr SanitizeVisitor::SanitizeInternal(MultiStatement& rootNode, std::string* error)
{ {
MultiStatementPtr output; MultiStatementPtr output;
PushScope(); //< Global scope
{ {
RegisterBuiltin();
// First pass, evaluate everything except function code // First pass, evaluate everything except function code
try try
{ {
@ -2196,7 +2323,6 @@ namespace Nz::ShaderAst
ResolveFunctions(); ResolveFunctions();
} }
PopScope();
return output; return output;
} }
@ -2261,7 +2387,7 @@ namespace Nz::ShaderAst
StackVector<TypeParameter> parameters = NazaraStackVector(TypeParameter, partialType.parameters.size()); StackVector<TypeParameter> parameters = NazaraStackVector(TypeParameter, partialType.parameters.size());
for (std::size_t i = 0; i < partialType.parameters.size(); ++i) for (std::size_t i = 0; i < partialType.parameters.size(); ++i)
{ {
ExpressionPtr indexExpr = CloneExpression(node.indices[i]); const ExpressionPtr& indexExpr = node.indices[i];
switch (partialType.parameters[i]) switch (partialType.parameters[i])
{ {
case TypeParameterCategory::ConstantValue: case TypeParameterCategory::ConstantValue:

View File

@ -137,9 +137,16 @@ namespace Nz
std::string targetName; std::string targetName;
}; };
struct StructData
{
std::string nameOverride;
const ShaderAst::StructDescription* desc;
};
std::optional<ShaderStageType> stage; std::optional<ShaderStageType> stage;
std::string moduleSuffix;
std::stringstream stream; std::stringstream stream;
std::unordered_map<std::size_t, ShaderAst::StructDescription*> structs; std::unordered_map<std::size_t, StructData> structs;
std::unordered_map<std::size_t, std::string> variableNames; std::unordered_map<std::size_t, std::string> variableNames;
std::vector<InOutField> inputFields; std::vector<InOutField> inputFields;
std::vector<InOutField> outputFields; std::vector<InOutField> outputFields;
@ -172,6 +179,7 @@ namespace Nz
else else
targetAst = module.rootNode.get(); targetAst = module.rootNode.get();
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
ShaderAst::StatementPtr optimizedAst; ShaderAst::StatementPtr optimizedAst;
if (states.optimize) if (states.optimize)
@ -193,6 +201,15 @@ namespace Nz
AppendHeader(); AppendHeader();
for (const auto& importedModule : targetModule.importedModules)
{
m_currentState->moduleSuffix = importedModule.identifier;
importedModule.module->rootNode->Visit(state.previsitor);
importedModule.module->rootNode->Visit(*this);
}
m_currentState->moduleSuffix = {};
targetAst->Visit(*this); targetAst->Visit(*this);
return state.stream.str(); return state.stream.str();
@ -342,8 +359,8 @@ namespace Nz
void GlslWriter::Append(const ShaderAst::StructType& structType) void GlslWriter::Append(const ShaderAst::StructType& structType)
{ {
ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structType.structIndex); const auto& structData = Retrieve(m_currentState->structs, structType.structIndex);
Append(structDesc->name); Append(structData.nameOverride);
} }
void GlslWriter::Append(const ShaderAst::Type& /*type*/) void GlslWriter::Append(const ShaderAst::Type& /*type*/)
@ -630,9 +647,9 @@ namespace Nz
assert(IsStructType(parameter.type.GetResultingValue())); assert(IsStructType(parameter.type.GetResultingValue()));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type.GetResultingValue()).structIndex; std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type.GetResultingValue()).structIndex;
const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); const auto& structData = Retrieve(m_currentState->structs, structIndex);
AppendLine(structDesc->name, " ", varName, ";"); AppendLine(structData.nameOverride, " ", varName, ";");
for (const auto& [memberName, targetName] : m_currentState->inputFields) for (const auto& [memberName, targetName] : m_currentState->inputFields)
AppendLine(varName, ".", memberName, " = ", targetName, ";"); AppendLine(varName, ".", memberName, " = ", targetName, ";");
@ -651,9 +668,9 @@ namespace Nz
void GlslWriter::HandleInOut() void GlslWriter::HandleInOut()
{ {
auto AppendInOut = [this](const ShaderAst::StructDescription& structDesc, std::vector<State::InOutField>& fields, const char* keyword, const char* targetPrefix) auto AppendInOut = [this](const State::StructData& structData, std::vector<State::InOutField>& fields, const char* keyword, const char* targetPrefix)
{ {
for (const auto& member : structDesc.members) for (const auto& member : structData.desc->members)
{ {
if (member.cond.HasValue() && !member.cond.GetResultingValue()) if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue; continue;
@ -691,8 +708,6 @@ namespace Nz
const ShaderAst::DeclareFunctionStatement& node = *m_currentState->previsitor.entryPoint; const ShaderAst::DeclareFunctionStatement& node = *m_currentState->previsitor.entryPoint;
const ShaderAst::StructDescription* inputStruct = nullptr;
if (!node.parameters.empty()) if (!node.parameters.empty())
{ {
assert(node.parameters.size() == 1); assert(node.parameters.size() == 1);
@ -700,10 +715,10 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type.GetResultingValue())); assert(std::holds_alternative<ShaderAst::StructType>(parameter.type.GetResultingValue()));
std::size_t inputStructIndex = std::get<ShaderAst::StructType>(parameter.type.GetResultingValue()).structIndex; std::size_t inputStructIndex = std::get<ShaderAst::StructType>(parameter.type.GetResultingValue()).structIndex;
inputStruct = Retrieve(m_currentState->structs, inputStructIndex); const auto& inputStruct = Retrieve(m_currentState->structs, inputStructIndex);
AppendCommentSection("Inputs"); AppendCommentSection("Inputs");
AppendInOut(*inputStruct, m_currentState->inputFields, "in", s_inputPrefix); AppendInOut(inputStruct, m_currentState->inputFields, "in", s_inputPrefix);
} }
if (m_currentState->stage == ShaderStageType::Vertex && m_environment.flipYPosition) if (m_currentState->stage == ShaderStageType::Vertex && m_environment.flipYPosition)
@ -717,17 +732,21 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType.GetResultingValue())); assert(std::holds_alternative<ShaderAst::StructType>(node.returnType.GetResultingValue()));
std::size_t outputStructIndex = std::get<ShaderAst::StructType>(node.returnType.GetResultingValue()).structIndex; std::size_t outputStructIndex = std::get<ShaderAst::StructType>(node.returnType.GetResultingValue()).structIndex;
const ShaderAst::StructDescription* outputStruct = Retrieve(m_currentState->structs, outputStructIndex); const auto& outputStruct = Retrieve(m_currentState->structs, outputStructIndex);
AppendCommentSection("Outputs"); AppendCommentSection("Outputs");
AppendInOut(*outputStruct, m_currentState->outputFields, "out", s_outputPrefix); AppendInOut(outputStruct, m_currentState->outputFields, "out", s_outputPrefix);
} }
} }
void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc) void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc, std::string structName)
{ {
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
m_currentState->structs.emplace(structIndex, desc); State::StructData structData;
structData.desc = desc;
structData.nameOverride = std::move(structName);
m_currentState->structs.emplace(structIndex, std::move(structData));
} }
void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName) void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName)
@ -1035,11 +1054,13 @@ namespace Nz
if (IsUniformType(externalVar.type.GetResultingValue())) if (IsUniformType(externalVar.type.GetResultingValue()))
{ {
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type.GetResultingValue()); auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type.GetResultingValue());
ShaderAst::StructDescription* structInfo = Retrieve(m_currentState->structs, uniform.containedType.structIndex); const auto& structInfo = Retrieve(m_currentState->structs, uniform.containedType.structIndex);
if (structInfo->layout.HasValue()) if (structInfo.desc->layout.HasValue())
isStd140 = structInfo->layout.GetResultingValue() == StructLayout::Std140; isStd140 = structInfo.desc->layout.GetResultingValue() == StructLayout::Std140;
} }
std::string varName = externalVar.name + m_currentState->moduleSuffix;
if (!m_currentState->bindingMapping.empty() || isStd140) if (!m_currentState->bindingMapping.empty() || isStd140)
Append("layout("); Append("layout(");
@ -1074,15 +1095,15 @@ namespace Nz
if (IsUniformType(externalVar.type.GetResultingValue())) if (IsUniformType(externalVar.type.GetResultingValue()))
{ {
Append("_NzBinding_"); Append("_NzBinding_");
AppendLine(externalVar.name); AppendLine(varName);
EnterScope(); EnterScope();
{ {
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type.GetResultingValue()); const auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type.GetResultingValue());
auto& structDesc = Retrieve(m_currentState->structs, uniform.containedType.structIndex); const auto& structData = Retrieve(m_currentState->structs, uniform.containedType.structIndex);
bool first = true; bool first = true;
for (const auto& member : structDesc->members) for (const auto& member : structData.desc->members)
{ {
if (member.cond.HasValue() && !member.cond.GetResultingValue()) if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue; continue;
@ -1099,10 +1120,10 @@ namespace Nz
LeaveScope(false); LeaveScope(false);
Append(" "); Append(" ");
Append(externalVar.name); Append(varName);
} }
else else
AppendVariableDeclaration(externalVar.type.GetResultingValue(), externalVar.name); AppendVariableDeclaration(externalVar.type.GetResultingValue(), varName);
AppendLine(";"); AppendLine(";");
@ -1110,7 +1131,7 @@ namespace Nz
AppendLine(); AppendLine();
assert(externalVar.varIndex); assert(externalVar.varIndex);
RegisterVariable(*externalVar.varIndex, externalVar.name); RegisterVariable(*externalVar.varIndex, varName);
} }
} }
@ -1168,11 +1189,13 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node) void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{ {
std::string structName = node.description.name + m_currentState->moduleSuffix;
assert(node.structIndex); assert(node.structIndex);
RegisterStruct(*node.structIndex, &node.description); RegisterStruct(*node.structIndex, &node.description, structName);
Append("struct "); Append("struct ");
AppendLine(node.description.name); AppendLine(structName);
EnterScope(); EnterScope();
{ {
bool first = true; bool first = true;
@ -1224,9 +1247,9 @@ namespace Nz
Append(";"); Append(";");
} }
void GlslWriter::Visit(ShaderAst::ImportStatement& /*node*/) void GlslWriter::Visit(ShaderAst::ImportStatement& node)
{ {
/* nothing to do */ throw std::runtime_error("unexpected import statement, is the shader sanitized properly?");
} }
void GlslWriter::Visit(ShaderAst::MultiStatement& node) void GlslWriter::Visit(ShaderAst::MultiStatement& node)
@ -1254,7 +1277,7 @@ namespace Nz
const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr); const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr);
assert(IsStructType(returnType)); assert(IsStructType(returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex; std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); const auto& structData = Retrieve(m_currentState->structs, structIndex);
std::string outputStructVarName; std::string outputStructVarName;
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression) if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression)
@ -1262,7 +1285,7 @@ namespace Nz
else else
{ {
AppendLine(); AppendLine();
Append(structDesc->name, " ", s_outputVarName, " = "); Append(structData.nameOverride, " ", s_outputVarName, " = ");
node.returnExpr->Visit(*this); node.returnExpr->Visit(*this);
AppendLine(";"); AppendLine(";");

View File

@ -31,73 +31,95 @@ namespace Nz
{ {
const ShaderAst::ExpressionValue<UInt32>& bindingIndex; const ShaderAst::ExpressionValue<UInt32>& bindingIndex;
inline bool HasValue() const { return bindingIndex.HasValue(); } bool HasValue() const { return bindingIndex.HasValue(); }
}; };
struct LangWriter::BuiltinAttribute struct LangWriter::BuiltinAttribute
{ {
const ShaderAst::ExpressionValue<ShaderAst::BuiltinEntry>& builtin; const ShaderAst::ExpressionValue<ShaderAst::BuiltinEntry>& builtin;
inline bool HasValue() const { return builtin.HasValue(); } bool HasValue() const { return builtin.HasValue(); }
}; };
struct LangWriter::DepthWriteAttribute struct LangWriter::DepthWriteAttribute
{ {
const ShaderAst::ExpressionValue<ShaderAst::DepthWriteMode>& writeMode; const ShaderAst::ExpressionValue<ShaderAst::DepthWriteMode>& writeMode;
inline bool HasValue() const { return writeMode.HasValue(); } bool HasValue() const { return writeMode.HasValue(); }
}; };
struct LangWriter::EarlyFragmentTestsAttribute struct LangWriter::EarlyFragmentTestsAttribute
{ {
const ShaderAst::ExpressionValue<bool>& earlyFragmentTests; const ShaderAst::ExpressionValue<bool>& earlyFragmentTests;
inline bool HasValue() const { return earlyFragmentTests.HasValue(); } bool HasValue() const { return earlyFragmentTests.HasValue(); }
}; };
struct LangWriter::EntryAttribute struct LangWriter::EntryAttribute
{ {
const ShaderAst::ExpressionValue<ShaderStageType>& stageType; const ShaderAst::ExpressionValue<ShaderStageType>& stageType;
inline bool HasValue() const { return stageType.HasValue(); } bool HasValue() const { return stageType.HasValue(); }
};
struct LangWriter::LangVersionAttribute
{
UInt32 version;
bool HasValue() const { return true; }
}; };
struct LangWriter::LayoutAttribute struct LangWriter::LayoutAttribute
{ {
const ShaderAst::ExpressionValue<StructLayout>& layout; const ShaderAst::ExpressionValue<StructLayout>& layout;
inline bool HasValue() const { return layout.HasValue(); } bool HasValue() const { return layout.HasValue(); }
}; };
struct LangWriter::LocationAttribute struct LangWriter::LocationAttribute
{ {
const ShaderAst::ExpressionValue<UInt32>& locationIndex; const ShaderAst::ExpressionValue<UInt32>& locationIndex;
inline bool HasValue() const { return locationIndex.HasValue(); } bool HasValue() const { return locationIndex.HasValue(); }
}; };
struct LangWriter::SetAttribute struct LangWriter::SetAttribute
{ {
const ShaderAst::ExpressionValue<UInt32>& setIndex; const ShaderAst::ExpressionValue<UInt32>& setIndex;
inline bool HasValue() const { return setIndex.HasValue(); } bool HasValue() const { return setIndex.HasValue(); }
}; };
struct LangWriter::UnrollAttribute struct LangWriter::UnrollAttribute
{ {
const ShaderAst::ExpressionValue<ShaderAst::LoopUnroll>& unroll; const ShaderAst::ExpressionValue<ShaderAst::LoopUnroll>& unroll;
inline bool HasValue() const { return unroll.HasValue(); } bool HasValue() const { return unroll.HasValue(); }
};
struct LangWriter::UuidAttribute
{
Uuid uuid;
bool HasValue() const { return true; }
}; };
struct LangWriter::State struct LangWriter::State
{ {
struct Identifier
{
std::size_t moduleIndex;
std::string name;
};
const States* states = nullptr; const States* states = nullptr;
ShaderAst::Module* module; ShaderAst::Module* module;
std::size_t currentModuleIndex;
std::stringstream stream; std::stringstream stream;
std::unordered_map<std::size_t, std::string> constantNames; std::unordered_map<std::size_t, Identifier> constantNames;
std::unordered_map<std::size_t, ShaderAst::StructDescription*> structs; std::unordered_map<std::size_t, Identifier> structs;
std::unordered_map<std::size_t, std::string> variableNames; std::unordered_map<std::size_t, Identifier> variableNames;
std::vector<std::string> moduleNames;
bool isInEntryPoint = false; bool isInEntryPoint = false;
unsigned int indentLevel = 0; unsigned int indentLevel = 0;
}; };
@ -116,6 +138,22 @@ namespace Nz
AppendHeader(); AppendHeader();
// Register imported modules
m_currentState->currentModuleIndex = 0;
for (const auto& importedModule : sanitizedModule->importedModules)
{
AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion });
AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId });
AppendLine("module ", importedModule.identifier);
EnterScope();
importedModule.module->rootNode->Visit(*this);
LeaveScope(true);
m_currentState->currentModuleIndex++;
m_currentState->moduleNames.push_back(importedModule.identifier);
}
m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max();
sanitizedModule->rootNode->Visit(*this); sanitizedModule->rootNode->Visit(*this);
return state.stream.str(); return state.stream.str();
@ -213,8 +251,7 @@ namespace Nz
void LangWriter::Append(const ShaderAst::StructType& structType) void LangWriter::Append(const ShaderAst::StructType& structType)
{ {
ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structType.structIndex); AppendIdentifier(m_currentState->structs, structType.structIndex);
Append(structDesc->name);
} }
void LangWriter::Append(const ShaderAst::Type& /*type*/) void LangWriter::Append(const ShaderAst::Type& /*type*/)
@ -292,31 +329,31 @@ namespace Nz
AppendAttributesInternal(first, secondParam, std::forward<Rest>(params)...); AppendAttributesInternal(first, secondParam, std::forward<Rest>(params)...);
} }
void LangWriter::AppendAttribute(BindingAttribute binding) void LangWriter::AppendAttribute(BindingAttribute attribute)
{ {
if (!binding.HasValue()) if (!attribute.HasValue())
return; return;
Append("binding("); Append("binding(");
if (binding.bindingIndex.IsResultingValue()) if (attribute.bindingIndex.IsResultingValue())
Append(binding.bindingIndex.GetResultingValue()); Append(attribute.bindingIndex.GetResultingValue());
else else
binding.bindingIndex.GetExpression()->Visit(*this); attribute.bindingIndex.GetExpression()->Visit(*this);
Append(")"); Append(")");
} }
void LangWriter::AppendAttribute(BuiltinAttribute builtin) void LangWriter::AppendAttribute(BuiltinAttribute attribute)
{ {
if (!builtin.HasValue()) if (!attribute.HasValue())
return; return;
Append("builtin("); Append("builtin(");
if (builtin.builtin.IsResultingValue()) if (attribute.builtin.IsResultingValue())
{ {
switch (builtin.builtin.GetResultingValue()) switch (attribute.builtin.GetResultingValue())
{ {
case ShaderAst::BuiltinEntry::FragCoord: case ShaderAst::BuiltinEntry::FragCoord:
Append("fragcoord"); Append("fragcoord");
@ -332,21 +369,21 @@ namespace Nz
} }
} }
else else
builtin.builtin.GetExpression()->Visit(*this); attribute.builtin.GetExpression()->Visit(*this);
Append(")"); Append(")");
} }
void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite) void LangWriter::AppendAttribute(DepthWriteAttribute attribute)
{ {
if (!depthWrite.HasValue()) if (!attribute.HasValue())
return; return;
Append("depth_write("); Append("depth_write(");
if (depthWrite.writeMode.IsResultingValue()) if (attribute.writeMode.IsResultingValue())
{ {
switch (depthWrite.writeMode.GetResultingValue()) switch (attribute.writeMode.GetResultingValue())
{ {
case ShaderAst::DepthWriteMode::Greater: case ShaderAst::DepthWriteMode::Greater:
Append("greater"); Append("greater");
@ -366,41 +403,41 @@ namespace Nz
} }
} }
else else
depthWrite.writeMode.GetExpression()->Visit(*this); attribute.writeMode.GetExpression()->Visit(*this);
Append(")"); Append(")");
} }
void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests) void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute attribute)
{ {
if (!earlyFragmentTests.HasValue()) if (!attribute.HasValue())
return; return;
Append("early_fragment_tests("); Append("early_fragment_tests(");
if (earlyFragmentTests.earlyFragmentTests.IsResultingValue()) if (attribute.earlyFragmentTests.IsResultingValue())
{ {
if (earlyFragmentTests.earlyFragmentTests.GetResultingValue()) if (attribute.earlyFragmentTests.GetResultingValue())
Append("true"); Append("true");
else else
Append("false"); Append("false");
} }
else else
earlyFragmentTests.earlyFragmentTests.GetExpression()->Visit(*this); attribute.earlyFragmentTests.GetExpression()->Visit(*this);
Append(")"); Append(")");
} }
void LangWriter::AppendAttribute(EntryAttribute entry) void LangWriter::AppendAttribute(EntryAttribute attribute)
{ {
if (!entry.HasValue()) if (!attribute.HasValue())
return; return;
Append("entry("); Append("entry(");
if (entry.stageType.IsResultingValue()) if (attribute.stageType.IsResultingValue())
{ {
switch (entry.stageType.GetResultingValue()) switch (attribute.stageType.GetResultingValue())
{ {
case ShaderStageType::Fragment: case ShaderStageType::Fragment:
Append("frag"); Append("frag");
@ -412,20 +449,39 @@ namespace Nz
} }
} }
else else
entry.stageType.GetExpression()->Visit(*this); attribute.stageType.GetExpression()->Visit(*this);
Append(")"); Append(")");
} }
void LangWriter::AppendAttribute(LayoutAttribute entry) void LangWriter::AppendAttribute(LangVersionAttribute attribute)
{ {
if (!entry.HasValue()) UInt32 shaderLangVersion = attribute.version;
UInt32 majorVersion = shaderLangVersion / 100;
shaderLangVersion -= majorVersion * 100;
UInt32 minorVersion = shaderLangVersion / 10;
shaderLangVersion -= minorVersion * 100;
UInt32 patchVersion = shaderLangVersion;
// nzsl_version
Append("nzsl_version(\"", majorVersion, ".", minorVersion);
if (patchVersion != 0)
Append(".", patchVersion);
Append("\")");
}
void LangWriter::AppendAttribute(LayoutAttribute attribute)
{
if (!attribute.HasValue())
return; return;
Append("layout("); Append("layout(");
if (entry.layout.IsResultingValue()) if (attribute.layout.IsResultingValue())
{ {
switch (entry.layout.GetResultingValue()) switch (attribute.layout.GetResultingValue())
{ {
case StructLayout::Packed: case StructLayout::Packed:
Append("packed"); Append("packed");
@ -437,50 +493,50 @@ namespace Nz
} }
} }
else else
entry.layout.GetExpression()->Visit(*this); attribute.layout.GetExpression()->Visit(*this);
Append(")"); Append(")");
} }
void LangWriter::AppendAttribute(LocationAttribute location) void LangWriter::AppendAttribute(LocationAttribute attribute)
{ {
if (!location.HasValue()) if (!attribute.HasValue())
return; return;
Append("location("); Append("location(");
if (location.locationIndex.IsResultingValue()) if (attribute.locationIndex.IsResultingValue())
Append(location.locationIndex.GetResultingValue()); Append(attribute.locationIndex.GetResultingValue());
else else
location.locationIndex.GetExpression()->Visit(*this); attribute.locationIndex.GetExpression()->Visit(*this);
Append(")"); Append(")");
} }
void LangWriter::AppendAttribute(SetAttribute set) void LangWriter::AppendAttribute(SetAttribute attribute)
{ {
if (!set.HasValue()) if (!attribute.HasValue())
return; return;
Append("set("); Append("set(");
if (set.setIndex.IsResultingValue()) if (attribute.setIndex.IsResultingValue())
Append(set.setIndex.GetResultingValue()); Append(attribute.setIndex.GetResultingValue());
else else
set.setIndex.GetExpression()->Visit(*this); attribute.setIndex.GetExpression()->Visit(*this);
Append(")"); Append(")");
} }
void LangWriter::AppendAttribute(UnrollAttribute unroll) void LangWriter::AppendAttribute(UnrollAttribute attribute)
{ {
if (!unroll.HasValue()) if (!attribute.HasValue())
return; return;
Append("unroll("); Append("unroll(");
if (unroll.unroll.IsResultingValue()) if (attribute.unroll.IsResultingValue())
{ {
switch (unroll.unroll.GetResultingValue()) switch (attribute.unroll.GetResultingValue())
{ {
case ShaderAst::LoopUnroll::Always: case ShaderAst::LoopUnroll::Always:
Append("always"); Append("always");
@ -499,7 +555,12 @@ namespace Nz
} }
} }
else else
unroll.unroll.GetExpression()->Visit(*this); attribute.unroll.GetExpression()->Visit(*this);
}
void LangWriter::AppendAttribute(UuidAttribute attribute)
{
Append("uuid(\"", attribute.uuid.ToString(), "\")");
} }
void LangWriter::AppendComment(const std::string& section) void LangWriter::AppendComment(const std::string& section)
@ -539,6 +600,16 @@ namespace Nz
m_currentState->stream << txt << '\n' << std::string(m_currentState->indentLevel, '\t'); m_currentState->stream << txt << '\n' << std::string(m_currentState->indentLevel, '\t');
} }
template<typename T>
void LangWriter::AppendIdentifier(const T& map, std::size_t id)
{
const auto& structIdentifier = Retrieve(map, id);
if (structIdentifier.moduleIndex != m_currentState->currentModuleIndex)
Append(m_currentState->moduleNames[structIdentifier.moduleIndex], '.');
Append(structIdentifier.name);
}
template<typename... Args> template<typename... Args>
void LangWriter::AppendLine(Args&&... params) void LangWriter::AppendLine(Args&&... params)
{ {
@ -586,20 +657,32 @@ namespace Nz
void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName) void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName)
{ {
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(constantName);
assert(m_currentState->constantNames.find(constantIndex) == m_currentState->constantNames.end()); assert(m_currentState->constantNames.find(constantIndex) == m_currentState->constantNames.end());
m_currentState->constantNames.emplace(constantIndex, std::move(constantName)); m_currentState->constantNames.emplace(constantIndex, std::move(identifier));
} }
void LangWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc) void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName)
{ {
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(structName);
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
m_currentState->structs.emplace(structIndex, desc); m_currentState->structs.emplace(structIndex, std::move(identifier));
} }
void LangWriter::RegisterVariable(std::size_t varIndex, std::string varName) void LangWriter::RegisterVariable(std::size_t varIndex, std::string varName)
{ {
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(varName);
assert(m_currentState->variableNames.find(varIndex) == m_currentState->variableNames.end()); assert(m_currentState->variableNames.find(varIndex) == m_currentState->variableNames.end());
m_currentState->variableNames.emplace(varIndex, std::move(varName)); m_currentState->variableNames.emplace(varIndex, std::move(identifier));
} }
void LangWriter::ScopeVisit(ShaderAst::Statement& node) void LangWriter::ScopeVisit(ShaderAst::Statement& node)
@ -770,7 +853,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node) void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
{ {
throw std::runtime_error("TODO"); //< missing registering //throw std::runtime_error("TODO"); //< missing registering
assert(node.aliasIndex); assert(node.aliasIndex);
@ -828,7 +911,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ConstantExpression& node) void LangWriter::Visit(ShaderAst::ConstantExpression& node)
{ {
Append(Retrieve(m_currentState->constantNames, node.constantId)); AppendIdentifier(m_currentState->constantNames, node.constantId);
} }
void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node) void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node)
@ -908,7 +991,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareStructStatement& node) void LangWriter::Visit(ShaderAst::DeclareStructStatement& node)
{ {
assert(node.structIndex); assert(node.structIndex);
RegisterStruct(*node.structIndex, &node.description); RegisterStruct(*node.structIndex, node.description.name);
AppendAttributes(true, LayoutAttribute{ node.description.layout }); AppendAttributes(true, LayoutAttribute{ node.description.layout });
Append("struct "); Append("struct ");
@ -1068,6 +1151,11 @@ namespace Nz
Append(")"); Append(")");
} }
void LangWriter::Visit(ShaderAst::StructTypeExpression& node)
{
AppendIdentifier(m_currentState->structs, node.structTypeId);
}
void LangWriter::Visit(ShaderAst::MultiStatement& node) void LangWriter::Visit(ShaderAst::MultiStatement& node)
{ {
if (!node.sectionName.empty()) if (!node.sectionName.empty())
@ -1100,7 +1188,7 @@ namespace Nz
{ {
EnterScope(); EnterScope();
node.statement->Visit(*this); node.statement->Visit(*this);
LeaveScope(true); LeaveScope();
} }
void LangWriter::Visit(ShaderAst::SwizzleExpression& node) void LangWriter::Visit(ShaderAst::SwizzleExpression& node)
@ -1115,8 +1203,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::VariableExpression& node) void LangWriter::Visit(ShaderAst::VariableExpression& node)
{ {
const std::string& varName = Retrieve(m_currentState->variableNames, node.variableId); AppendIdentifier(m_currentState->variableNames, node.variableId);
Append(varName);
} }
void LangWriter::Visit(ShaderAst::UnaryExpression& node) void LangWriter::Visit(ShaderAst::UnaryExpression& node)
@ -1150,22 +1237,7 @@ namespace Nz
void LangWriter::AppendHeader() void LangWriter::AppendHeader()
{ {
UInt32 shaderLangVersion = m_currentState->module->metadata->shaderLangVersion; AppendAttributes(true, LangVersionAttribute{ m_currentState->module->metadata->shaderLangVersion });
UInt32 majorVersion = shaderLangVersion / 100;
shaderLangVersion -= majorVersion * 100;
UInt32 minorVersion = shaderLangVersion / 10;
shaderLangVersion -= minorVersion * 100;
UInt32 patchVersion = shaderLangVersion;
// nzsl_version
Append("[nzsl_version(\"", majorVersion, ".", minorVersion);
if (patchVersion != 0)
Append(".", patchVersion);
AppendLine("\")]");
AppendLine("module;"); AppendLine("module;");
AppendLine(); AppendLine();
} }

View File

@ -40,6 +40,7 @@ namespace Nz::ShaderLang
ForceCLocale forceCLocale; ForceCLocale forceCLocale;
std::unordered_map<std::string, TokenType> reservedKeywords = { std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "alias", TokenType::Alias },
{ "const", TokenType::Const }, { "const", TokenType::Const },
{ "const_select", TokenType::ConstSelect }, { "const_select", TokenType::ConstSelect },
{ "discard", TokenType::Discard }, { "discard", TokenType::Discard },

View File

@ -246,6 +246,9 @@ namespace Nz::ShaderLang
void Parser::ParseModuleStatement(std::vector<ShaderAst::ExprValue> attributes) void Parser::ParseModuleStatement(std::vector<ShaderAst::ExprValue> attributes)
{ {
if (m_context->parsingImportedModule)
throw UnexpectedToken{};
Expect(Advance(), TokenType::Module); Expect(Advance(), TokenType::Module);
std::optional<UInt32> moduleVersion; std::optional<UInt32> moduleVersion;
@ -343,7 +346,21 @@ namespace Nz::ShaderLang
const std::string& identifier = std::get<std::string>(Peek().data); const std::string& identifier = std::get<std::string>(Peek().data);
Consume(); Consume();
module->rootNode->statements = ParseStatementList(); Expect(Advance(), TokenType::OpenCurlyBracket);
m_context->parsingImportedModule = true;
while (Peek().type != TokenType::ClosingCurlyBracket)
{
ShaderAst::StatementPtr statement = ParseRootStatement();
if (!statement)
throw UnexpectedToken{}; //< "unexpected end of file"
module->rootNode->statements.push_back(std::move(statement));
}
Consume(); //< Consume ClosingCurlyBracket
m_context->parsingImportedModule = false;
auto& importedModule = m_context->module->importedModules.emplace_back(); auto& importedModule = m_context->module->importedModules.emplace_back();
importedModule.module = std::move(module); importedModule.module = std::move(module);
@ -380,6 +397,21 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Semicolon); Expect(Advance(), TokenType::Semicolon);
} }
ShaderAst::StatementPtr Parser::ParseAliasDeclaration()
{
Expect(Advance(), TokenType::Alias);
std::string name = ParseIdentifierAsName();
Expect(Advance(), TokenType::Assign);
ShaderAst::ExpressionPtr expr = ParseExpression();
Expect(Advance(), TokenType::Semicolon);
return ShaderBuilder::DeclareAlias(std::move(name), std::move(expr));
}
ShaderAst::StatementPtr Parser::ParseBranchStatement() ShaderAst::StatementPtr Parser::ParseBranchStatement()
{ {
std::unique_ptr<ShaderAst::BranchStatement> branch = std::make_unique<ShaderAst::BranchStatement>(); std::unique_ptr<ShaderAst::BranchStatement> branch = std::make_unique<ShaderAst::BranchStatement>();
@ -756,6 +788,12 @@ namespace Nz::ShaderLang
const Token& nextToken = Peek(); const Token& nextToken = Peek();
switch (nextToken.type) switch (nextToken.type)
{ {
case TokenType::Alias:
if (!attributes.empty())
throw UnexpectedToken{};
return ParseAliasDeclaration();
case TokenType::Const: case TokenType::Const:
if (!attributes.empty()) if (!attributes.empty())
throw UnexpectedToken{}; throw UnexpectedToken{};

View File

@ -517,6 +517,8 @@ namespace Nz
else else
targetAst = module.rootNode.get(); targetAst = module.rootNode.get();
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
ShaderAst::StatementPtr optimizedAst; ShaderAst::StatementPtr optimizedAst;
if (states.optimize) if (states.optimize)
{ {
@ -542,6 +544,9 @@ namespace Nz
// Register all extended instruction sets // Register all extended instruction sets
PreVisitor preVisitor(state.constantTypeCache, state.funcs); PreVisitor preVisitor(state.constantTypeCache, state.funcs);
for (const auto& importedModule : targetModule.importedModules)
importedModule.module->rootNode->Visit(preVisitor);
targetAst->Visit(preVisitor); targetAst->Visit(preVisitor);
m_currentState->preVisitor = &preVisitor; m_currentState->preVisitor = &preVisitor;
@ -554,6 +559,9 @@ namespace Nz
func.funcId = AllocateResultId(); func.funcId = AllocateResultId();
SpirvAstVisitor visitor(*this, state.instructions, state.funcs); SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
for (const auto& importedModule : targetModule.importedModules)
importedModule.module->rootNode->Visit(visitor);
targetAst->Visit(visitor); targetAst->Visit(visitor);
AppendHeader(); AppendHeader();