Modules are workings \o/

This commit is contained in:
Jérôme Leclercq 2022-03-08 20:26:02 +01:00
parent 83d26e209e
commit be9bdc4705
29 changed files with 742 additions and 256 deletions

View File

@ -9,16 +9,26 @@
NAZARA_REQUEST_DEDICATED_GPU()
const char moduleSource[] = R"(
const char barModuleSource[] = R"(
[nzsl_version("1.0")]
[uuid("4BB09DEE-F70A-442E-859F-E8F2F3F8583D")]
module;
fn dummy() {}
[export]
[layout(std140)]
struct Bar
{
}
)";
const char dataModuleSource[] = R"(
[nzsl_version("1.0")]
[uuid("E49DC9AD-469C-462C-9719-A6F012372029")]
module;
import Test/Bar;
struct Foo {}
@ -37,14 +47,15 @@ const char shaderSource[] = R"(
[nzsl_version("1.0")]
module;
import Test/module_test;
import Test/Data;
import Test/Bar;
option red: bool = false;
[set(0)]
external
{
[binding(0)] viewerData: uniform[Data],
[binding(0)] viewerData: uniform[Data]
}
[set(1)]
@ -138,10 +149,12 @@ int main()
if (modulePath[0] != "Test")
return {};
if (modulePath[1] != "module_test")
if (modulePath[1] == "Bar")
return Nz::ShaderLang::Parse(std::string_view(barModuleSource, sizeof(barModuleSource)));
else if (modulePath[1] == "Data")
return Nz::ShaderLang::Parse(std::string_view(dataModuleSource, sizeof(dataModuleSource)));
else
return {};
return Nz::ShaderLang::Parse(std::string_view(moduleSource, sizeof(moduleSource)));
};
shaderModule = Nz::ShaderAst::Sanitize(*shaderModule, sanitizeOpt);
@ -152,10 +165,12 @@ int main()
}
Nz::LangWriter langWriter;
std::cout << langWriter.Generate(*shaderModule) << std::endl;
std::string output = langWriter.Generate(*shaderModule);
std::cout << output << std::endl;
assert(Nz::ShaderAst::Sanitize(*Nz::ShaderLang::Parse(output)));
Nz::ShaderWriter::States states;
states.optimize = true;
states.optimize = false;
auto fragVertShader = device->InstantiateShaderModule(Nz::ShaderStageType::Fragment | Nz::ShaderStageType::Vertex, *shaderModule, states);
if (!fragVertShader)

View File

@ -50,8 +50,11 @@ namespace Nz::ShaderAst
virtual ExpressionPtr Clone(ConditionalExpression& node);
virtual ExpressionPtr Clone(ConstantExpression& node);
virtual ExpressionPtr Clone(ConstantValueExpression& node);
virtual ExpressionPtr Clone(FunctionExpression& node);
virtual ExpressionPtr Clone(IdentifierExpression& node);
virtual ExpressionPtr Clone(IntrinsicExpression& node);
virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node);
virtual ExpressionPtr Clone(StructTypeExpression& node);
virtual ExpressionPtr Clone(SwizzleExpression& node);
virtual ExpressionPtr Clone(VariableExpression& node);
virtual ExpressionPtr Clone(UnaryExpression& node);

View File

@ -42,8 +42,11 @@ namespace Nz::ShaderAst
inline bool Compare(const ConditionalExpression& lhs, const ConditionalExpression& rhs);
inline bool Compare(const ConstantExpression& lhs, const ConstantExpression& rhs);
inline bool Compare(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs);
inline bool Compare(const FunctionExpression& lhs, const FunctionExpression& rhs);
inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs);
inline bool Compare(const IntrinsicExpression& lhs, const IntrinsicExpression& rhs);
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs);
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs);
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs);
inline bool Compare(const VariableExpression& lhs, const VariableExpression& rhs);
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs);

View File

@ -342,6 +342,14 @@ namespace Nz::ShaderAst
return true;
}
inline bool Compare(const FunctionExpression& lhs, const FunctionExpression& rhs)
{
if (!Compare(lhs.funcId, rhs.funcId))
return false;
return true;
}
inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs)
{
if (!Compare(lhs.identifier, rhs.identifier))
@ -361,6 +369,22 @@ namespace Nz::ShaderAst
return true;
}
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs)
{
if (!Compare(lhs.intrinsicId, rhs.intrinsicId))
return false;
return true;
}
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs)
{
if (!Compare(lhs.structTypeId, rhs.structTypeId))
return false;
return true;
}
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs)
{
if (!Compare(lhs.componentCount, rhs.componentCount))

View File

@ -38,8 +38,11 @@ NAZARA_SHADERAST_EXPRESSION(CastExpression)
NAZARA_SHADERAST_EXPRESSION(ConditionalExpression)
NAZARA_SHADERAST_EXPRESSION(ConstantExpression)
NAZARA_SHADERAST_EXPRESSION(ConstantValueExpression)
NAZARA_SHADERAST_EXPRESSION(FunctionExpression)
NAZARA_SHADERAST_EXPRESSION(IdentifierExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression)
NAZARA_SHADERAST_EXPRESSION(StructTypeExpression)
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
NAZARA_SHADERAST_EXPRESSION(VariableExpression)
NAZARA_SHADERAST_EXPRESSION(UnaryExpression)

View File

@ -30,8 +30,11 @@ namespace Nz::ShaderAst
void Visit(ConditionalExpression& node) override;
void Visit(ConstantValueExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(FunctionExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(VariableExpression& node) override;
void Visit(UnaryExpression& node) override;

View File

@ -33,8 +33,11 @@ namespace Nz::ShaderAst
void Serialize(ConstantExpression& node);
void Serialize(ConditionalExpression& node);
void Serialize(ConstantValueExpression& node);
void Serialize(FunctionExpression& node);
void Serialize(IdentifierExpression& node);
void Serialize(IntrinsicExpression& node);
void Serialize(IntrinsicFunctionExpression& node);
void Serialize(StructTypeExpression& node);
void Serialize(SwizzleExpression& node);
void Serialize(VariableExpression& node);
void Serialize(UnaryExpression& node);

View File

@ -41,8 +41,11 @@ namespace Nz::ShaderAst
void Visit(ConditionalExpression& node) override;
void Visit(ConstantValueExpression& node) override;
void Visit(ConstantExpression& node) override;
void Visit(FunctionExpression& node) override;
void Visit(IdentifierExpression& node) override;
void Visit(IntrinsicExpression& node) override;
void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(VariableExpression& node) override;
void Visit(UnaryExpression& node) override;

View File

@ -16,6 +16,9 @@ namespace Nz::ShaderAst
inline ModulePtr EliminateUnusedPass(const Module& shaderModule, const DependencyCheckerVisitor::Config& config)
{
DependencyCheckerVisitor dependencyVisitor;
for (const auto& importedModule : shaderModule.importedModules)
dependencyVisitor.Process(*importedModule.module->rootNode, config);
dependencyVisitor.Process(*shaderModule.rootNode, config);
dependencyVisitor.Resolve();

View File

@ -45,7 +45,8 @@ namespace Nz::ShaderAst
StatementPtr Clone(DeclareStructStatement& node) override;
StatementPtr Clone(DeclareVariableStatement& node) override;
ExpressionPtr Clone(CallFunctionExpression& node) override;
ExpressionPtr Clone(FunctionExpression& node) override;
ExpressionPtr Clone(StructTypeExpression& node) override;
ExpressionPtr Clone(VariableExpression& node) override;
void HandleType(ExpressionValue<ExpressionType>& exprType);

View File

@ -156,6 +156,14 @@ namespace Nz::ShaderAst
ShaderAst::ConstantValue value;
};
struct NAZARA_SHADER_API FunctionExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t funcId;
};
struct NAZARA_SHADER_API IdentifierExpression : Expression
{
NodeType GetType() const override;
@ -173,6 +181,22 @@ namespace Nz::ShaderAst
IntrinsicType intrinsic;
};
struct NAZARA_SHADER_API IntrinsicFunctionExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t intrinsicId;
};
struct NAZARA_SHADER_API StructTypeExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t structTypeId;
};
struct NAZARA_SHADER_API SwizzleExpression : Expression
{
NodeType GetType() const override;

View File

@ -108,6 +108,8 @@ namespace Nz::ShaderAst
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
Expression& MandatoryExpr(const ExpressionPtr& node) const;
Statement& MandatoryStatement(const StatementPtr& node) const;

View File

@ -85,7 +85,7 @@ namespace Nz
void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node);
void HandleInOut();
void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc);
void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc, std::string structName);
void RegisterVariable(std::size_t varIndex, std::string varName);
void ScopeVisit(ShaderAst::Statement& node);

View File

@ -43,10 +43,12 @@ namespace Nz
struct DepthWriteAttribute;
struct EarlyFragmentTestsAttribute;
struct EntryAttribute;
struct LangVersionAttribute;
struct LayoutAttribute;
struct LocationAttribute;
struct SetAttribute;
struct UnrollAttribute;
struct UuidAttribute;
void Append(const ShaderAst::ArrayType& type);
void Append(const ShaderAst::ExpressionType& type);
@ -68,18 +70,21 @@ namespace Nz
template<typename... Args> void AppendAttributes(bool appendLine, Args&&... params);
template<typename T> void AppendAttributesInternal(bool& first, const T& param);
template<typename T1, typename T2, typename... Rest> void AppendAttributesInternal(bool& first, const T1& firstParam, const T2& secondParam, Rest&&... params);
void AppendAttribute(BindingAttribute binding);
void AppendAttribute(BuiltinAttribute builtin);
void AppendAttribute(DepthWriteAttribute depthWrite);
void AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests);
void AppendAttribute(EntryAttribute entry);
void AppendAttribute(LayoutAttribute layout);
void AppendAttribute(LocationAttribute location);
void AppendAttribute(SetAttribute set);
void AppendAttribute(UnrollAttribute unroll);
void AppendAttribute(BindingAttribute attribute);
void AppendAttribute(BuiltinAttribute attribute);
void AppendAttribute(DepthWriteAttribute attribute);
void AppendAttribute(EarlyFragmentTestsAttribute attribute);
void AppendAttribute(EntryAttribute attribute);
void AppendAttribute(LangVersionAttribute attribute);
void AppendAttribute(LayoutAttribute attribute);
void AppendAttribute(LocationAttribute attribute);
void AppendAttribute(SetAttribute seattributet);
void AppendAttribute(UnrollAttribute attribute);
void AppendAttribute(UuidAttribute attribute);
void AppendComment(const std::string& section);
void AppendCommentSection(const std::string& section);
void AppendHeader();
template<typename T> void AppendIdentifier(const T& map, std::size_t id);
void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params);
void AppendStatementList(std::vector<ShaderAst::StatementPtr>& statements);
@ -88,7 +93,7 @@ namespace Nz
void LeaveScope(bool skipLine = true);
void RegisterConstant(std::size_t constantIndex, std::string constantName);
void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc);
void RegisterStruct(std::size_t structIndex, std::string structName);
void RegisterVariable(std::size_t varIndex, std::string varName);
void ScopeVisit(ShaderAst::Statement& node);
@ -104,6 +109,7 @@ namespace Nz
void Visit(ShaderAst::ConstantValueExpression& node) override;
void Visit(ShaderAst::ConstantExpression& node) override;
void Visit(ShaderAst::IntrinsicExpression& node) override;
void Visit(ShaderAst::StructTypeExpression& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override;
void Visit(ShaderAst::UnaryExpression& node) override;

View File

@ -127,6 +127,11 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::ForEachStatement> operator()(std::string varName, ShaderAst::ExpressionPtr expression, ShaderAst::StatementPtr statement) const;
};
struct Function
{
inline std::unique_ptr<ShaderAst::FunctionExpression> operator()(std::size_t funcId) const;
};
struct Identifier
{
inline std::unique_ptr<ShaderAst::IdentifierExpression> operator()(std::string name) const;
@ -142,6 +147,11 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::IntrinsicExpression> operator()(ShaderAst::IntrinsicType intrinsicType, std::vector<ShaderAst::ExpressionPtr> parameters) const;
};
struct IntrinsicFunction
{
inline std::unique_ptr<ShaderAst::IntrinsicFunctionExpression> operator()(std::size_t intrinsicFunctionId, ShaderAst::IntrinsicType intrinsicType) const;
};
struct Multi
{
inline std::unique_ptr<ShaderAst::MultiStatement> operator()(std::vector<ShaderAst::StatementPtr> statements = {}) const;
@ -163,6 +173,11 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::ScopedStatement> operator()(ShaderAst::StatementPtr statement) const;
};
struct StructType
{
inline std::unique_ptr<ShaderAst::StructTypeExpression> operator()(std::size_t structTypeId) const;
};
struct Swizzle
{
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const;
@ -206,13 +221,16 @@ namespace Nz::ShaderBuilder
constexpr Impl::NoParam<ShaderAst::DiscardStatement> Discard;
constexpr Impl::For For;
constexpr Impl::ForEach ForEach;
constexpr Impl::Function Function;
constexpr Impl::Identifier Identifier;
constexpr Impl::IntrinsicFunction IntrinsicFunction;
constexpr Impl::Import Import;
constexpr Impl::Intrinsic Intrinsic;
constexpr Impl::Multi MultiStatement;
constexpr Impl::NoParam<ShaderAst::NoOpStatement> NoOp;
constexpr Impl::Return Return;
constexpr Impl::Scoped Scoped;
constexpr Impl::StructType StructType;
constexpr Impl::Swizzle Swizzle;
constexpr Impl::Unary Unary;
constexpr Impl::Variable Variable;

View File

@ -315,7 +315,7 @@ namespace Nz::ShaderBuilder
return expressionStatementNode;
}
inline std::unique_ptr<ShaderAst::ForStatement> Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const
inline std::unique_ptr<ShaderAst::ForStatement> Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::StatementPtr statement) const
{
auto forNode = std::make_unique<ShaderAst::ForStatement>();
forNode->fromExpr = std::move(fromExpression);
@ -326,7 +326,7 @@ namespace Nz::ShaderBuilder
return forNode;
}
inline std::unique_ptr<ShaderAst::ForStatement> Nz::ShaderBuilder::Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const
inline std::unique_ptr<ShaderAst::ForStatement> Impl::For::operator()(std::string varName, ShaderAst::ExpressionPtr fromExpression, ShaderAst::ExpressionPtr toExpression, ShaderAst::ExpressionPtr stepExpression, ShaderAst::StatementPtr statement) const
{
auto forNode = std::make_unique<ShaderAst::ForStatement>();
forNode->fromExpr = std::move(fromExpression);
@ -348,6 +348,15 @@ namespace Nz::ShaderBuilder
return forEachNode;
}
inline std::unique_ptr<ShaderAst::FunctionExpression> Impl::Function::operator()(std::size_t funcId) const
{
auto intrinsicTypeExpr = std::make_unique<ShaderAst::FunctionExpression>();
intrinsicTypeExpr->cachedExpressionType = ShaderAst::FunctionType{ funcId };
intrinsicTypeExpr->funcId = funcId;
return intrinsicTypeExpr;
}
inline std::unique_ptr<ShaderAst::IdentifierExpression> Impl::Identifier::operator()(std::string name) const
{
auto identifierNode = std::make_unique<ShaderAst::IdentifierExpression>();
@ -373,6 +382,15 @@ namespace Nz::ShaderBuilder
return intrinsicExpression;
}
inline std::unique_ptr<ShaderAst::IntrinsicFunctionExpression> Impl::IntrinsicFunction::operator()(std::size_t intrinsicFunctionId, ShaderAst::IntrinsicType intrinsicType) const
{
auto intrinsicTypeExpr = std::make_unique<ShaderAst::IntrinsicFunctionExpression>();
intrinsicTypeExpr->cachedExpressionType = ShaderAst::IntrinsicFunctionType{ intrinsicType };
intrinsicTypeExpr->intrinsicId = intrinsicFunctionId;
return intrinsicTypeExpr;
}
inline std::unique_ptr<ShaderAst::MultiStatement> Impl::Multi::operator()(std::vector<ShaderAst::StatementPtr> statements) const
{
auto multiStatement = std::make_unique<ShaderAst::MultiStatement>();
@ -403,6 +421,15 @@ namespace Nz::ShaderBuilder
return scopedNode;
}
inline std::unique_ptr<ShaderAst::StructTypeExpression> Impl::StructType::operator()(std::size_t structTypeId) const
{
auto structTypeExpr = std::make_unique<ShaderAst::StructTypeExpression>();
structTypeExpr->cachedExpressionType = ShaderAst::StructType{ structTypeId };
structTypeExpr->structTypeId = structTypeId;
return structTypeExpr;
}
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::array<UInt32, 4> swizzleComponents, std::size_t componentCount) const
{
assert(componentCount > 0);

View File

@ -86,6 +86,7 @@ namespace Nz::ShaderLang
void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type, ShaderAst::ExpressionPtr& initialValue);
// Statements
ShaderAst::StatementPtr ParseAliasDeclaration();
ShaderAst::StatementPtr ParseBranchStatement();
ShaderAst::StatementPtr ParseConstStatement();
ShaderAst::StatementPtr ParseDiscardStatement();
@ -130,6 +131,7 @@ namespace Nz::ShaderLang
std::size_t tokenIndex = 0;
ShaderAst::ModulePtr module;
const Token* tokens;
bool parsingImportedModule = false;
};
Context* m_context;

View File

@ -12,6 +12,7 @@
#define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X)
#endif
NAZARA_SHADERLANG_TOKEN(Alias)
NAZARA_SHADERLANG_TOKEN(Arrow)
NAZARA_SHADERLANG_TOKEN(Assign)
NAZARA_SHADERLANG_TOKEN(BoolFalse)

View File

@ -405,6 +405,16 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(FunctionExpression& node)
{
auto clone = std::make_unique<FunctionExpression>();
clone->funcId = node.funcId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(IdentifierExpression& node)
{
auto clone = std::make_unique<IdentifierExpression>();
@ -429,6 +439,26 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(IntrinsicFunctionExpression& node)
{
auto clone = std::make_unique<IntrinsicFunctionExpression>();
clone->intrinsicId = node.intrinsicId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(StructTypeExpression& node)
{
auto clone = std::make_unique<StructTypeExpression>();
clone->structTypeId = node.structTypeId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(SwizzleExpression& node)
{
auto clone = std::make_unique<SwizzleExpression>();

View File

@ -72,6 +72,11 @@ namespace Nz::ShaderAst
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(FunctionExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/)
{
/* Nothing to do */
@ -83,6 +88,16 @@ namespace Nz::ShaderAst
param->Visit(*this);
}
void AstRecursiveVisitor::Visit(IntrinsicFunctionExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(StructTypeExpression& /*node*/)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(SwizzleExpression& node)
{
if (node.expression)

View File

@ -153,6 +153,21 @@ namespace Nz::ShaderAst
Node(param);
}
void AstSerializerBase::Serialize(IntrinsicFunctionExpression& node)
{
SizeT(node.intrinsicId);
}
void AstSerializerBase::Serialize(StructTypeExpression& node)
{
SizeT(node.structTypeId);
}
void AstSerializerBase::Serialize(FunctionExpression& node)
{
SizeT(node.funcId);
}
void AstSerializerBase::Serialize(SwizzleExpression& node)
{
SizeT(node.componentCount);

View File

@ -72,6 +72,11 @@ namespace Nz::ShaderAst
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(FunctionExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
@ -82,6 +87,16 @@ namespace Nz::ShaderAst
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(IntrinsicFunctionExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(StructTypeExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
{
if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1)

View File

@ -137,20 +137,22 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr IndexRemapperVisitor::Clone(CallFunctionExpression& node)
ExpressionPtr IndexRemapperVisitor::Clone(FunctionExpression& node)
{
CallFunctionExpressionPtr clone = static_unique_pointer_cast<CallFunctionExpression>(AstCloner::Clone(node));
FunctionExpressionPtr clone = static_unique_pointer_cast<FunctionExpression>(AstCloner::Clone(node));
const auto& targetFuncType = GetExpressionType(*node.targetFunction);
if (std::holds_alternative<FunctionType>(targetFuncType))
{
const auto& funcType = std::get<FunctionType>(targetFuncType);
assert(clone->funcId);
clone->funcId = Retrieve(m_context->newFuncIndices, clone->funcId);
FunctionType newFunc;
newFunc.funcIndex = Retrieve(m_context->newFuncIndices, funcType.funcIndex);
clone->cachedExpressionType = ExpressionType{ newFunc }; //< FIXME We should add FunctionExpression like VariableExpression to handle this
}
return clone;
}
ExpressionPtr IndexRemapperVisitor::Clone(StructTypeExpression& node)
{
StructTypeExpressionPtr clone = static_unique_pointer_cast<StructTypeExpression>(AstCloner::Clone(node));
assert(clone->structTypeId);
clone->structTypeId = Retrieve(m_context->newStructIndices, clone->structTypeId);
return clone;
}

View File

@ -7,6 +7,7 @@
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/StackArray.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Core/Hash/SHA256.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/Ast/AstConstantPropagationVisitor.hpp>
#include <Nazara/Shader/Ast/AstExportVisitor.hpp>
@ -114,6 +115,7 @@ namespace Nz::ShaderAst
struct SanitizeVisitor::Environment
{
Uuid moduleId;
std::shared_ptr<Environment> parentEnv;
std::vector<Identifier> identifiersInScope;
std::vector<Scope> scopes;
@ -121,21 +123,31 @@ namespace Nz::ShaderAst
struct SanitizeVisitor::Context
{
struct ModuleData
{
std::unordered_map<Uuid, DependencyCheckerVisitor::UsageSet> exportedSetByModule;
std::shared_ptr<Environment> environment;
std::unique_ptr<DependencyCheckerVisitor> dependenciesVisitor;
};
struct PendingFunction
{
DeclareFunctionStatement* cloneNode;
const DeclareFunctionStatement* node;
};
static constexpr std::size_t ModuleIdSentinel = std::numeric_limits<std::size_t>::max();
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::vector<std::shared_ptr<Environment>> moduleEnvironments;
std::vector<ModuleData> modules;
std::vector<PendingFunction> pendingFunctions;
std::vector<StatementPtr>* currentStatementList = nullptr;
std::unordered_map<Uuid, std::size_t> moduleByUuid;
std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<UInt64> usedBindingIndexes;
std::shared_ptr<Environment> globalEnv;
std::shared_ptr<Environment> builtinEnv;
std::shared_ptr<Environment> currentEnv;
std::shared_ptr<Environment> globalEnv;
IdentifierList<ConstantValue> constantValues;
IdentifierList<FunctionData> functions;
IdentifierList<IdentifierData> aliases;
@ -144,22 +156,57 @@ namespace Nz::ShaderAst
IdentifierList<StructDescription*> structs;
IdentifierList<std::variant<ExpressionType, PartialType>> types;
IdentifierList<ExpressionType> variableTypes;
ModulePtr currentModule;
Options options;
CurrentFunctionData* currentFunction = nullptr;
};
ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error)
{
Context currentContext;
currentContext.options = options;
ModulePtr clone = std::make_shared<Module>(module.metadata);
clone->importedModules = module.importedModules;
Context currentContext;
currentContext.options = options;
currentContext.currentModule = clone;
m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; });
// Register builtin env
m_context->builtinEnv = std::make_shared<Environment>();
m_context->currentEnv = m_context->builtinEnv;
RegisterBuiltin();
m_context->globalEnv = std::make_shared<Environment>();
m_context->globalEnv->moduleId = clone->metadata->moduleId;
m_context->globalEnv->parentEnv = m_context->builtinEnv;
for (std::size_t moduleId = 0; moduleId < clone->importedModules.size(); ++moduleId)
{
auto moduleEnv = std::make_shared<Environment>();
moduleEnv->moduleId = clone->importedModules[moduleId].module->metadata->moduleId;
moduleEnv->parentEnv = m_context->builtinEnv;
m_context->currentEnv = moduleEnv;
// Previous modules are visibles
for (std::size_t previousModuleId = 0; previousModuleId < moduleId; ++previousModuleId)
RegisterModule(clone->importedModules[previousModuleId].identifier, previousModuleId);
auto& importedModule = clone->importedModules[moduleId];
importedModule.module->rootNode = SanitizeInternal(*importedModule.module->rootNode, error);
if (!importedModule.module->rootNode)
return {};
m_context->moduleByUuid[importedModule.module->metadata->moduleId] = moduleId;
auto& moduleData = m_context->modules.emplace_back();
moduleData.environment = std::move(moduleEnv);
m_context->currentEnv = m_context->globalEnv;
RegisterModule(importedModule.identifier, moduleId);
}
m_context->currentEnv = m_context->globalEnv;
clone->rootNode = SanitizeInternal(*module.rootNode, error);
@ -211,7 +258,25 @@ namespace Nz::ShaderAst
if (node.identifiers.empty())
throw AstError{ "AccessIdentifierExpression must have at least one identifier" };
ExpressionPtr indexedExpr = CloneExpression(MandatoryExpr(node.expr));
MandatoryExpr(node.expr);
// Handle module access (TODO: Add namespace expression?)
if (node.expr->GetType() == NodeType::IdentifierExpression && node.identifiers.size() == 1)
{
auto& identifierExpr = static_cast<IdentifierExpression&>(*node.expr);
const IdentifierData* identifierData = FindIdentifier(identifierExpr.identifier);
if (identifierData && identifierData->category == IdentifierCategory::Module)
{
std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index);
const auto& env = *m_context->modules[moduleIndex].environment;
identifierData = FindIdentifier(env, node.identifiers.front());
if (identifierData)
return HandleIdentifier(identifierData);
}
}
ExpressionPtr indexedExpr = CloneExpression(node.expr);
for (const std::string& identifier : node.identifiers)
{
if (identifier.empty())
@ -393,7 +458,10 @@ namespace Nz::ShaderAst
if (!m_context->currentFunction)
throw AstError{ "function calls must happen inside a function" };
std::size_t targetFuncIndex = std::get<FunctionType>(targetExprType).funcIndex;
if (targetExpr->GetType() != NodeType::FunctionExpression)
throw AstError{ "expected function expression" };
std::size_t targetFuncIndex = static_cast<FunctionExpression&>(*targetExpr).funcId;
auto clone = std::make_unique<CallFunctionExpression>();
clone->targetFunction = std::move(targetExpr);
@ -410,13 +478,18 @@ namespace Nz::ShaderAst
}
else if (IsIntrinsicFunctionType(targetExprType))
{
if (targetExpr->GetType() != NodeType::IntrinsicFunctionExpression)
throw AstError{ "expected intrinsic function expression" };
std::size_t targetIntrinsicId = static_cast<IntrinsicFunctionExpression&>(*targetExpr).intrinsicId;
std::vector<ExpressionPtr> parameters;
parameters.reserve(node.parameters.size());
for (const auto& param : node.parameters)
parameters.push_back(CloneExpression(param));
auto intrinsic = ShaderBuilder::Intrinsic(std::get<IntrinsicFunctionType>(targetExprType).intrinsic, std::move(parameters));
auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics.Retrieve(targetIntrinsicId), std::move(parameters));
Validate(*intrinsic);
return intrinsic;
@ -584,64 +657,7 @@ namespace Nz::ShaderAst
if (!identifierData)
throw AstError{ "unknown identifier " + node.identifier };
switch (identifierData->category)
{
case IdentifierCategory::Constant:
{
// Replace IdentifierExpression by Constant(Value)Expression
ConstantExpression constantExpr;
constantExpr.constantId = identifierData->index;
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
}
case IdentifierCategory::Function:
{
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = FunctionType{ identifierData->index };
return clone;
}
case IdentifierCategory::Intrinsic:
{
IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index);
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = IntrinsicFunctionType{ intrinsicType };
return clone;
}
case IdentifierCategory::Struct:
{
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = StructType{ identifierData->index };
return clone;
}
case IdentifierCategory::Type:
{
auto clone = AstCloner::Clone(node);
clone->cachedExpressionType = Type{ identifierData->index };
return clone;
}
case IdentifierCategory::Variable:
{
// Replace IdentifierExpression by VariableExpression
auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index);
varExpr->variableId = identifierData->index;
return varExpr;
}
default:
throw AstError{ "unexpected identifier" };
}
return HandleIdentifier(identifierData);
}
ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node)
@ -872,6 +888,8 @@ namespace Nz::ShaderAst
if (node.returnType.HasValue())
clone->returnType = ResolveType(node.returnType);
else
clone->returnType = ExpressionType{ NoType{} };
if (node.depthWrite.HasValue())
clone->depthWrite = ComputeExprValue(node.depthWrite);
@ -1360,54 +1378,99 @@ namespace Nz::ShaderAst
if (!targetModule)
throw AstError{ "module " + ModulePathAsString() + " not found" };
targetModule->rootNode->sectionName = "Module " + targetModule->metadata->moduleId.ToString();
std::size_t moduleIndex;
m_context->currentEnv = m_context->moduleEnvironments.emplace_back(std::make_shared<Environment>());
CallOnExit restoreEnvOnExit([&] { m_context->currentEnv = m_context->globalEnv; });
const Uuid& moduleUuid = targetModule->metadata->moduleId;
auto it = m_context->moduleByUuid.find(moduleUuid);
if (it == m_context->moduleByUuid.end())
{
m_context->moduleByUuid[moduleUuid] = Context::ModuleIdSentinel;
ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata);
// Generate module identifier (based on UUID)
const auto& moduleUuidBytes = moduleUuid.ToArray();
std::string error;
sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error);
if (!sanitizedModule)
throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error };
SHA256Hash hasher;
hasher.Begin();
hasher.Append(moduleUuidBytes.data(), moduleUuidBytes.size());
hasher.End();
std::string identifier = "__" + hasher.End().ToHex().substr(0, 8);
// Load new module
auto moduleEnvironment = std::make_shared<Environment>();
moduleEnvironment->parentEnv = m_context->builtinEnv;
auto previousEnv = m_context->currentEnv;
m_context->currentEnv = moduleEnvironment;
ModulePtr sanitizedModule = std::make_shared<Module>(targetModule->metadata);
std::string error;
sanitizedModule->rootNode = SanitizeInternal(*targetModule->rootNode, &error);
if (!sanitizedModule)
throw AstError{ "module " + ModulePathAsString() + " compilation failed: " + error };
moduleIndex = m_context->modules.size();
assert(m_context->modules.size() == moduleIndex);
auto& moduleData = m_context->modules.emplace_back();
moduleData.dependenciesVisitor = std::make_unique<DependencyCheckerVisitor>();
moduleData.dependenciesVisitor->Process(*sanitizedModule->rootNode);
moduleData.environment = std::move(moduleEnvironment);
assert(m_context->currentModule->importedModules.size() == moduleIndex);
auto& importedModule = m_context->currentModule->importedModules.emplace_back();
importedModule.identifier = identifier;
importedModule.module = std::move(sanitizedModule);
m_context->currentEnv = std::move(previousEnv);
RegisterModule(identifier, moduleIndex);
m_context->moduleByUuid[moduleUuid] = moduleIndex;
}
else
{
// Module has already been imported
moduleIndex = it->second;
if (moduleIndex == Context::ModuleIdSentinel)
throw AstError{ "circular import detected" };
}
auto& moduleData = m_context->modules[moduleIndex];
auto& exportedSet = moduleData.exportedSetByModule[m_context->currentEnv->moduleId];
// Extract exported nodes and their dependencies
DependencyCheckerVisitor::Config depConfig;
depConfig.usedShaderStages.Clear();
DependencyCheckerVisitor moduleDependencies;
moduleDependencies.Process(*sanitizedModule->rootNode, depConfig);
DependencyCheckerVisitor::UsageSet exportedSet;
MultiStatementPtr aliasBlock = std::make_unique<MultiStatement>();
std::vector<DeclareAliasStatementPtr> aliasStatements;
AstExportVisitor::Callbacks callbacks;
callbacks.onExportedStruct = [&](DeclareStructStatement& node)
{
assert(node.structIndex);
moduleDependencies.MarkStructAsUsed(*node.structIndex);
exportedSet.usedStructs.UnboundedSet(*node.structIndex);
auto alias = Clone(node);
// TODO: DeclareAlias
aliasBlock->statements.emplace_back(std::move(alias));
moduleData.dependenciesVisitor->MarkStructAsUsed(*node.structIndex);
if (!exportedSet.usedStructs.UnboundedTest(*node.structIndex))
{
exportedSet.usedStructs.UnboundedSet(*node.structIndex);
aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.description.name, ShaderBuilder::StructType(*node.structIndex)));
}
};
AstExportVisitor exportVisitor;
exportVisitor.Visit(*sanitizedModule->rootNode, callbacks);
exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks);
moduleDependencies.Resolve();
if (aliasStatements.empty())
return ShaderBuilder::NoOp();
//m_context->
// Register exported variables (FIXME: This shouldn't be necessary and could be handled by the IndexRemapperVisitor)
//m_context->importUsage = remappedExportedSet;
//CallOnExit restoreImportOnExit([&] { m_context->importUsage.reset(); });
// Register module and aliases
MultiStatementPtr aliasBlock = std::make_unique<MultiStatement>();
for (auto& aliasPtr : aliasStatements)
{
Validate(*aliasPtr);
aliasBlock->statements.push_back(std::move(aliasPtr));
}
return aliasBlock;
}
@ -1559,6 +1622,74 @@ namespace Nz::ShaderAst
throw std::runtime_error("internal error");
}
ExpressionPtr SanitizeVisitor::HandleIdentifier(const IdentifierData* identifierData)
{
switch (identifierData->category)
{
case IdentifierCategory::Constant:
{
// Replace IdentifierExpression by Constant(Value)Expression
ConstantExpression constantExpr;
constantExpr.constantId = identifierData->index;
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
}
case IdentifierCategory::Function:
{
// Replace IdentifierExpression by FunctionExpression
auto funcExpr = std::make_unique<FunctionExpression>();
funcExpr->cachedExpressionType = FunctionType{ identifierData->index }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type
funcExpr->funcId = identifierData->index;
return funcExpr;
}
case IdentifierCategory::Intrinsic:
{
IntrinsicType intrinsicType = m_context->intrinsics.Retrieve(identifierData->index);
// Replace IdentifierExpression by IntrinsicFunctionExpression
auto intrinsicExpr = std::make_unique<IntrinsicFunctionExpression>();
intrinsicExpr->cachedExpressionType = IntrinsicFunctionType{ intrinsicType }; //< FIXME: Functions (and intrinsic) should be typed by their parameters/return type
intrinsicExpr->intrinsicId = identifierData->index;
return intrinsicExpr;
}
case IdentifierCategory::Struct:
{
// Replace IdentifierExpression by StructTypeExpression
auto structExpr = std::make_unique<StructTypeExpression>();
structExpr->cachedExpressionType = StructType{ identifierData->index };
structExpr->structTypeId = identifierData->index;
return structExpr;
}
case IdentifierCategory::Type:
{
auto clone = ShaderBuilder::Identifier("dummy");
clone->cachedExpressionType = Type{ identifierData->index };
return clone;
}
case IdentifierCategory::Variable:
{
// Replace IdentifierExpression by VariableExpression
auto varExpr = std::make_unique<VariableExpression>();
varExpr->cachedExpressionType = m_context->variableTypes.Retrieve(identifierData->index);
varExpr->variableId = identifierData->index;
return varExpr;
}
default:
throw AstError{ "unexpected identifier" };
}
}
Expression& SanitizeVisitor::MandatoryExpr(const ExpressionPtr& node) const
{
if (!node)
@ -1909,20 +2040,20 @@ namespace Nz::ShaderAst
return intrinsicIndex;
}
std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex)
std::size_t SanitizeVisitor::RegisterModule(std::string moduleIdentifier, std::size_t index)
{
if (FindIdentifier(moduleIdentifier))
throw AstError{ moduleIdentifier + " is already used" };
std::size_t intrinsicIndex = m_context->moduleIndices.Register(moduleIndex);
std::size_t moduleIndex = m_context->moduleIndices.Register(index);
m_context->currentEnv->identifiersInScope.push_back({
std::move(moduleIdentifier),
intrinsicIndex,
moduleIndex,
IdentifierCategory::Module
});
return intrinsicIndex;
return moduleIndex;
}
std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index)
@ -2169,11 +2300,7 @@ namespace Nz::ShaderAst
MultiStatementPtr SanitizeVisitor::SanitizeInternal(MultiStatement& rootNode, std::string* error)
{
MultiStatementPtr output;
PushScope(); //< Global scope
{
RegisterBuiltin();
// First pass, evaluate everything except function code
try
{
@ -2196,7 +2323,6 @@ namespace Nz::ShaderAst
ResolveFunctions();
}
PopScope();
return output;
}
@ -2261,7 +2387,7 @@ namespace Nz::ShaderAst
StackVector<TypeParameter> parameters = NazaraStackVector(TypeParameter, partialType.parameters.size());
for (std::size_t i = 0; i < partialType.parameters.size(); ++i)
{
ExpressionPtr indexExpr = CloneExpression(node.indices[i]);
const ExpressionPtr& indexExpr = node.indices[i];
switch (partialType.parameters[i])
{
case TypeParameterCategory::ConstantValue:

View File

@ -137,9 +137,16 @@ namespace Nz
std::string targetName;
};
struct StructData
{
std::string nameOverride;
const ShaderAst::StructDescription* desc;
};
std::optional<ShaderStageType> stage;
std::string moduleSuffix;
std::stringstream stream;
std::unordered_map<std::size_t, ShaderAst::StructDescription*> structs;
std::unordered_map<std::size_t, StructData> structs;
std::unordered_map<std::size_t, std::string> variableNames;
std::vector<InOutField> inputFields;
std::vector<InOutField> outputFields;
@ -172,6 +179,7 @@ namespace Nz
else
targetAst = module.rootNode.get();
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
ShaderAst::StatementPtr optimizedAst;
if (states.optimize)
@ -193,6 +201,15 @@ namespace Nz
AppendHeader();
for (const auto& importedModule : targetModule.importedModules)
{
m_currentState->moduleSuffix = importedModule.identifier;
importedModule.module->rootNode->Visit(state.previsitor);
importedModule.module->rootNode->Visit(*this);
}
m_currentState->moduleSuffix = {};
targetAst->Visit(*this);
return state.stream.str();
@ -342,8 +359,8 @@ namespace Nz
void GlslWriter::Append(const ShaderAst::StructType& structType)
{
ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structType.structIndex);
Append(structDesc->name);
const auto& structData = Retrieve(m_currentState->structs, structType.structIndex);
Append(structData.nameOverride);
}
void GlslWriter::Append(const ShaderAst::Type& /*type*/)
@ -630,9 +647,9 @@ namespace Nz
assert(IsStructType(parameter.type.GetResultingValue()));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type.GetResultingValue()).structIndex;
const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex);
const auto& structData = Retrieve(m_currentState->structs, structIndex);
AppendLine(structDesc->name, " ", varName, ";");
AppendLine(structData.nameOverride, " ", varName, ";");
for (const auto& [memberName, targetName] : m_currentState->inputFields)
AppendLine(varName, ".", memberName, " = ", targetName, ";");
@ -651,9 +668,9 @@ namespace Nz
void GlslWriter::HandleInOut()
{
auto AppendInOut = [this](const ShaderAst::StructDescription& structDesc, std::vector<State::InOutField>& fields, const char* keyword, const char* targetPrefix)
auto AppendInOut = [this](const State::StructData& structData, std::vector<State::InOutField>& fields, const char* keyword, const char* targetPrefix)
{
for (const auto& member : structDesc.members)
for (const auto& member : structData.desc->members)
{
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
@ -691,8 +708,6 @@ namespace Nz
const ShaderAst::DeclareFunctionStatement& node = *m_currentState->previsitor.entryPoint;
const ShaderAst::StructDescription* inputStruct = nullptr;
if (!node.parameters.empty())
{
assert(node.parameters.size() == 1);
@ -700,10 +715,10 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type.GetResultingValue()));
std::size_t inputStructIndex = std::get<ShaderAst::StructType>(parameter.type.GetResultingValue()).structIndex;
inputStruct = Retrieve(m_currentState->structs, inputStructIndex);
const auto& inputStruct = Retrieve(m_currentState->structs, inputStructIndex);
AppendCommentSection("Inputs");
AppendInOut(*inputStruct, m_currentState->inputFields, "in", s_inputPrefix);
AppendInOut(inputStruct, m_currentState->inputFields, "in", s_inputPrefix);
}
if (m_currentState->stage == ShaderStageType::Vertex && m_environment.flipYPosition)
@ -717,17 +732,21 @@ namespace Nz
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType.GetResultingValue()));
std::size_t outputStructIndex = std::get<ShaderAst::StructType>(node.returnType.GetResultingValue()).structIndex;
const ShaderAst::StructDescription* outputStruct = Retrieve(m_currentState->structs, outputStructIndex);
const auto& outputStruct = Retrieve(m_currentState->structs, outputStructIndex);
AppendCommentSection("Outputs");
AppendInOut(*outputStruct, m_currentState->outputFields, "out", s_outputPrefix);
AppendInOut(outputStruct, m_currentState->outputFields, "out", s_outputPrefix);
}
}
void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc)
void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc, std::string structName)
{
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
m_currentState->structs.emplace(structIndex, desc);
State::StructData structData;
structData.desc = desc;
structData.nameOverride = std::move(structName);
m_currentState->structs.emplace(structIndex, std::move(structData));
}
void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName)
@ -1035,11 +1054,13 @@ namespace Nz
if (IsUniformType(externalVar.type.GetResultingValue()))
{
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type.GetResultingValue());
ShaderAst::StructDescription* structInfo = Retrieve(m_currentState->structs, uniform.containedType.structIndex);
if (structInfo->layout.HasValue())
isStd140 = structInfo->layout.GetResultingValue() == StructLayout::Std140;
const auto& structInfo = Retrieve(m_currentState->structs, uniform.containedType.structIndex);
if (structInfo.desc->layout.HasValue())
isStd140 = structInfo.desc->layout.GetResultingValue() == StructLayout::Std140;
}
std::string varName = externalVar.name + m_currentState->moduleSuffix;
if (!m_currentState->bindingMapping.empty() || isStd140)
Append("layout(");
@ -1074,15 +1095,15 @@ namespace Nz
if (IsUniformType(externalVar.type.GetResultingValue()))
{
Append("_NzBinding_");
AppendLine(externalVar.name);
AppendLine(varName);
EnterScope();
{
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type.GetResultingValue());
auto& structDesc = Retrieve(m_currentState->structs, uniform.containedType.structIndex);
const auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type.GetResultingValue());
const auto& structData = Retrieve(m_currentState->structs, uniform.containedType.structIndex);
bool first = true;
for (const auto& member : structDesc->members)
for (const auto& member : structData.desc->members)
{
if (member.cond.HasValue() && !member.cond.GetResultingValue())
continue;
@ -1099,10 +1120,10 @@ namespace Nz
LeaveScope(false);
Append(" ");
Append(externalVar.name);
Append(varName);
}
else
AppendVariableDeclaration(externalVar.type.GetResultingValue(), externalVar.name);
AppendVariableDeclaration(externalVar.type.GetResultingValue(), varName);
AppendLine(";");
@ -1110,7 +1131,7 @@ namespace Nz
AppendLine();
assert(externalVar.varIndex);
RegisterVariable(*externalVar.varIndex, externalVar.name);
RegisterVariable(*externalVar.varIndex, varName);
}
}
@ -1168,11 +1189,13 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
std::string structName = node.description.name + m_currentState->moduleSuffix;
assert(node.structIndex);
RegisterStruct(*node.structIndex, &node.description);
RegisterStruct(*node.structIndex, &node.description, structName);
Append("struct ");
AppendLine(node.description.name);
AppendLine(structName);
EnterScope();
{
bool first = true;
@ -1224,9 +1247,9 @@ namespace Nz
Append(";");
}
void GlslWriter::Visit(ShaderAst::ImportStatement& /*node*/)
void GlslWriter::Visit(ShaderAst::ImportStatement& node)
{
/* nothing to do */
throw std::runtime_error("unexpected import statement, is the shader sanitized properly?");
}
void GlslWriter::Visit(ShaderAst::MultiStatement& node)
@ -1254,7 +1277,7 @@ namespace Nz
const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr);
assert(IsStructType(returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex);
const auto& structData = Retrieve(m_currentState->structs, structIndex);
std::string outputStructVarName;
if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression)
@ -1262,7 +1285,7 @@ namespace Nz
else
{
AppendLine();
Append(structDesc->name, " ", s_outputVarName, " = ");
Append(structData.nameOverride, " ", s_outputVarName, " = ");
node.returnExpr->Visit(*this);
AppendLine(";");

View File

@ -31,73 +31,95 @@ namespace Nz
{
const ShaderAst::ExpressionValue<UInt32>& bindingIndex;
inline bool HasValue() const { return bindingIndex.HasValue(); }
bool HasValue() const { return bindingIndex.HasValue(); }
};
struct LangWriter::BuiltinAttribute
{
const ShaderAst::ExpressionValue<ShaderAst::BuiltinEntry>& builtin;
inline bool HasValue() const { return builtin.HasValue(); }
bool HasValue() const { return builtin.HasValue(); }
};
struct LangWriter::DepthWriteAttribute
{
const ShaderAst::ExpressionValue<ShaderAst::DepthWriteMode>& writeMode;
inline bool HasValue() const { return writeMode.HasValue(); }
bool HasValue() const { return writeMode.HasValue(); }
};
struct LangWriter::EarlyFragmentTestsAttribute
{
const ShaderAst::ExpressionValue<bool>& earlyFragmentTests;
inline bool HasValue() const { return earlyFragmentTests.HasValue(); }
bool HasValue() const { return earlyFragmentTests.HasValue(); }
};
struct LangWriter::EntryAttribute
{
const ShaderAst::ExpressionValue<ShaderStageType>& stageType;
inline bool HasValue() const { return stageType.HasValue(); }
bool HasValue() const { return stageType.HasValue(); }
};
struct LangWriter::LangVersionAttribute
{
UInt32 version;
bool HasValue() const { return true; }
};
struct LangWriter::LayoutAttribute
{
const ShaderAst::ExpressionValue<StructLayout>& layout;
inline bool HasValue() const { return layout.HasValue(); }
bool HasValue() const { return layout.HasValue(); }
};
struct LangWriter::LocationAttribute
{
const ShaderAst::ExpressionValue<UInt32>& locationIndex;
inline bool HasValue() const { return locationIndex.HasValue(); }
bool HasValue() const { return locationIndex.HasValue(); }
};
struct LangWriter::SetAttribute
{
const ShaderAst::ExpressionValue<UInt32>& setIndex;
inline bool HasValue() const { return setIndex.HasValue(); }
bool HasValue() const { return setIndex.HasValue(); }
};
struct LangWriter::UnrollAttribute
{
const ShaderAst::ExpressionValue<ShaderAst::LoopUnroll>& unroll;
inline bool HasValue() const { return unroll.HasValue(); }
bool HasValue() const { return unroll.HasValue(); }
};
struct LangWriter::UuidAttribute
{
Uuid uuid;
bool HasValue() const { return true; }
};
struct LangWriter::State
{
struct Identifier
{
std::size_t moduleIndex;
std::string name;
};
const States* states = nullptr;
ShaderAst::Module* module;
std::size_t currentModuleIndex;
std::stringstream stream;
std::unordered_map<std::size_t, std::string> constantNames;
std::unordered_map<std::size_t, ShaderAst::StructDescription*> structs;
std::unordered_map<std::size_t, std::string> variableNames;
std::unordered_map<std::size_t, Identifier> constantNames;
std::unordered_map<std::size_t, Identifier> structs;
std::unordered_map<std::size_t, Identifier> variableNames;
std::vector<std::string> moduleNames;
bool isInEntryPoint = false;
unsigned int indentLevel = 0;
};
@ -116,6 +138,22 @@ namespace Nz
AppendHeader();
// Register imported modules
m_currentState->currentModuleIndex = 0;
for (const auto& importedModule : sanitizedModule->importedModules)
{
AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion });
AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId });
AppendLine("module ", importedModule.identifier);
EnterScope();
importedModule.module->rootNode->Visit(*this);
LeaveScope(true);
m_currentState->currentModuleIndex++;
m_currentState->moduleNames.push_back(importedModule.identifier);
}
m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max();
sanitizedModule->rootNode->Visit(*this);
return state.stream.str();
@ -213,8 +251,7 @@ namespace Nz
void LangWriter::Append(const ShaderAst::StructType& structType)
{
ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structType.structIndex);
Append(structDesc->name);
AppendIdentifier(m_currentState->structs, structType.structIndex);
}
void LangWriter::Append(const ShaderAst::Type& /*type*/)
@ -292,31 +329,31 @@ namespace Nz
AppendAttributesInternal(first, secondParam, std::forward<Rest>(params)...);
}
void LangWriter::AppendAttribute(BindingAttribute binding)
void LangWriter::AppendAttribute(BindingAttribute attribute)
{
if (!binding.HasValue())
if (!attribute.HasValue())
return;
Append("binding(");
if (binding.bindingIndex.IsResultingValue())
Append(binding.bindingIndex.GetResultingValue());
if (attribute.bindingIndex.IsResultingValue())
Append(attribute.bindingIndex.GetResultingValue());
else
binding.bindingIndex.GetExpression()->Visit(*this);
attribute.bindingIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(BuiltinAttribute builtin)
void LangWriter::AppendAttribute(BuiltinAttribute attribute)
{
if (!builtin.HasValue())
if (!attribute.HasValue())
return;
Append("builtin(");
if (builtin.builtin.IsResultingValue())
if (attribute.builtin.IsResultingValue())
{
switch (builtin.builtin.GetResultingValue())
switch (attribute.builtin.GetResultingValue())
{
case ShaderAst::BuiltinEntry::FragCoord:
Append("fragcoord");
@ -332,21 +369,21 @@ namespace Nz
}
}
else
builtin.builtin.GetExpression()->Visit(*this);
attribute.builtin.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite)
void LangWriter::AppendAttribute(DepthWriteAttribute attribute)
{
if (!depthWrite.HasValue())
if (!attribute.HasValue())
return;
Append("depth_write(");
if (depthWrite.writeMode.IsResultingValue())
if (attribute.writeMode.IsResultingValue())
{
switch (depthWrite.writeMode.GetResultingValue())
switch (attribute.writeMode.GetResultingValue())
{
case ShaderAst::DepthWriteMode::Greater:
Append("greater");
@ -366,41 +403,41 @@ namespace Nz
}
}
else
depthWrite.writeMode.GetExpression()->Visit(*this);
attribute.writeMode.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests)
void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute attribute)
{
if (!earlyFragmentTests.HasValue())
if (!attribute.HasValue())
return;
Append("early_fragment_tests(");
if (earlyFragmentTests.earlyFragmentTests.IsResultingValue())
if (attribute.earlyFragmentTests.IsResultingValue())
{
if (earlyFragmentTests.earlyFragmentTests.GetResultingValue())
if (attribute.earlyFragmentTests.GetResultingValue())
Append("true");
else
Append("false");
}
else
earlyFragmentTests.earlyFragmentTests.GetExpression()->Visit(*this);
attribute.earlyFragmentTests.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(EntryAttribute entry)
void LangWriter::AppendAttribute(EntryAttribute attribute)
{
if (!entry.HasValue())
if (!attribute.HasValue())
return;
Append("entry(");
if (entry.stageType.IsResultingValue())
if (attribute.stageType.IsResultingValue())
{
switch (entry.stageType.GetResultingValue())
switch (attribute.stageType.GetResultingValue())
{
case ShaderStageType::Fragment:
Append("frag");
@ -412,20 +449,39 @@ namespace Nz
}
}
else
entry.stageType.GetExpression()->Visit(*this);
attribute.stageType.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(LayoutAttribute entry)
void LangWriter::AppendAttribute(LangVersionAttribute attribute)
{
if (!entry.HasValue())
UInt32 shaderLangVersion = attribute.version;
UInt32 majorVersion = shaderLangVersion / 100;
shaderLangVersion -= majorVersion * 100;
UInt32 minorVersion = shaderLangVersion / 10;
shaderLangVersion -= minorVersion * 100;
UInt32 patchVersion = shaderLangVersion;
// nzsl_version
Append("nzsl_version(\"", majorVersion, ".", minorVersion);
if (patchVersion != 0)
Append(".", patchVersion);
Append("\")");
}
void LangWriter::AppendAttribute(LayoutAttribute attribute)
{
if (!attribute.HasValue())
return;
Append("layout(");
if (entry.layout.IsResultingValue())
if (attribute.layout.IsResultingValue())
{
switch (entry.layout.GetResultingValue())
switch (attribute.layout.GetResultingValue())
{
case StructLayout::Packed:
Append("packed");
@ -437,50 +493,50 @@ namespace Nz
}
}
else
entry.layout.GetExpression()->Visit(*this);
attribute.layout.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(LocationAttribute location)
void LangWriter::AppendAttribute(LocationAttribute attribute)
{
if (!location.HasValue())
if (!attribute.HasValue())
return;
Append("location(");
if (location.locationIndex.IsResultingValue())
Append(location.locationIndex.GetResultingValue());
if (attribute.locationIndex.IsResultingValue())
Append(attribute.locationIndex.GetResultingValue());
else
location.locationIndex.GetExpression()->Visit(*this);
attribute.locationIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(SetAttribute set)
void LangWriter::AppendAttribute(SetAttribute attribute)
{
if (!set.HasValue())
if (!attribute.HasValue())
return;
Append("set(");
if (set.setIndex.IsResultingValue())
Append(set.setIndex.GetResultingValue());
if (attribute.setIndex.IsResultingValue())
Append(attribute.setIndex.GetResultingValue());
else
set.setIndex.GetExpression()->Visit(*this);
attribute.setIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(UnrollAttribute unroll)
void LangWriter::AppendAttribute(UnrollAttribute attribute)
{
if (!unroll.HasValue())
if (!attribute.HasValue())
return;
Append("unroll(");
if (unroll.unroll.IsResultingValue())
if (attribute.unroll.IsResultingValue())
{
switch (unroll.unroll.GetResultingValue())
switch (attribute.unroll.GetResultingValue())
{
case ShaderAst::LoopUnroll::Always:
Append("always");
@ -499,7 +555,12 @@ namespace Nz
}
}
else
unroll.unroll.GetExpression()->Visit(*this);
attribute.unroll.GetExpression()->Visit(*this);
}
void LangWriter::AppendAttribute(UuidAttribute attribute)
{
Append("uuid(\"", attribute.uuid.ToString(), "\")");
}
void LangWriter::AppendComment(const std::string& section)
@ -539,6 +600,16 @@ namespace Nz
m_currentState->stream << txt << '\n' << std::string(m_currentState->indentLevel, '\t');
}
template<typename T>
void LangWriter::AppendIdentifier(const T& map, std::size_t id)
{
const auto& structIdentifier = Retrieve(map, id);
if (structIdentifier.moduleIndex != m_currentState->currentModuleIndex)
Append(m_currentState->moduleNames[structIdentifier.moduleIndex], '.');
Append(structIdentifier.name);
}
template<typename... Args>
void LangWriter::AppendLine(Args&&... params)
{
@ -586,20 +657,32 @@ namespace Nz
void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(constantName);
assert(m_currentState->constantNames.find(constantIndex) == m_currentState->constantNames.end());
m_currentState->constantNames.emplace(constantIndex, std::move(constantName));
m_currentState->constantNames.emplace(constantIndex, std::move(identifier));
}
void LangWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc)
void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(structName);
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
m_currentState->structs.emplace(structIndex, desc);
m_currentState->structs.emplace(structIndex, std::move(identifier));
}
void LangWriter::RegisterVariable(std::size_t varIndex, std::string varName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(varName);
assert(m_currentState->variableNames.find(varIndex) == m_currentState->variableNames.end());
m_currentState->variableNames.emplace(varIndex, std::move(varName));
m_currentState->variableNames.emplace(varIndex, std::move(identifier));
}
void LangWriter::ScopeVisit(ShaderAst::Statement& node)
@ -770,7 +853,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
{
throw std::runtime_error("TODO"); //< missing registering
//throw std::runtime_error("TODO"); //< missing registering
assert(node.aliasIndex);
@ -828,7 +911,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ConstantExpression& node)
{
Append(Retrieve(m_currentState->constantNames, node.constantId));
AppendIdentifier(m_currentState->constantNames, node.constantId);
}
void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node)
@ -908,7 +991,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
assert(node.structIndex);
RegisterStruct(*node.structIndex, &node.description);
RegisterStruct(*node.structIndex, node.description.name);
AppendAttributes(true, LayoutAttribute{ node.description.layout });
Append("struct ");
@ -1068,6 +1151,11 @@ namespace Nz
Append(")");
}
void LangWriter::Visit(ShaderAst::StructTypeExpression& node)
{
AppendIdentifier(m_currentState->structs, node.structTypeId);
}
void LangWriter::Visit(ShaderAst::MultiStatement& node)
{
if (!node.sectionName.empty())
@ -1100,7 +1188,7 @@ namespace Nz
{
EnterScope();
node.statement->Visit(*this);
LeaveScope(true);
LeaveScope();
}
void LangWriter::Visit(ShaderAst::SwizzleExpression& node)
@ -1115,8 +1203,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::VariableExpression& node)
{
const std::string& varName = Retrieve(m_currentState->variableNames, node.variableId);
Append(varName);
AppendIdentifier(m_currentState->variableNames, node.variableId);
}
void LangWriter::Visit(ShaderAst::UnaryExpression& node)
@ -1150,22 +1237,7 @@ namespace Nz
void LangWriter::AppendHeader()
{
UInt32 shaderLangVersion = m_currentState->module->metadata->shaderLangVersion;
UInt32 majorVersion = shaderLangVersion / 100;
shaderLangVersion -= majorVersion * 100;
UInt32 minorVersion = shaderLangVersion / 10;
shaderLangVersion -= minorVersion * 100;
UInt32 patchVersion = shaderLangVersion;
// nzsl_version
Append("[nzsl_version(\"", majorVersion, ".", minorVersion);
if (patchVersion != 0)
Append(".", patchVersion);
AppendLine("\")]");
AppendAttributes(true, LangVersionAttribute{ m_currentState->module->metadata->shaderLangVersion });
AppendLine("module;");
AppendLine();
}

View File

@ -40,6 +40,7 @@ namespace Nz::ShaderLang
ForceCLocale forceCLocale;
std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "alias", TokenType::Alias },
{ "const", TokenType::Const },
{ "const_select", TokenType::ConstSelect },
{ "discard", TokenType::Discard },

View File

@ -246,6 +246,9 @@ namespace Nz::ShaderLang
void Parser::ParseModuleStatement(std::vector<ShaderAst::ExprValue> attributes)
{
if (m_context->parsingImportedModule)
throw UnexpectedToken{};
Expect(Advance(), TokenType::Module);
std::optional<UInt32> moduleVersion;
@ -343,7 +346,21 @@ namespace Nz::ShaderLang
const std::string& identifier = std::get<std::string>(Peek().data);
Consume();
module->rootNode->statements = ParseStatementList();
Expect(Advance(), TokenType::OpenCurlyBracket);
m_context->parsingImportedModule = true;
while (Peek().type != TokenType::ClosingCurlyBracket)
{
ShaderAst::StatementPtr statement = ParseRootStatement();
if (!statement)
throw UnexpectedToken{}; //< "unexpected end of file"
module->rootNode->statements.push_back(std::move(statement));
}
Consume(); //< Consume ClosingCurlyBracket
m_context->parsingImportedModule = false;
auto& importedModule = m_context->module->importedModules.emplace_back();
importedModule.module = std::move(module);
@ -380,6 +397,21 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Semicolon);
}
ShaderAst::StatementPtr Parser::ParseAliasDeclaration()
{
Expect(Advance(), TokenType::Alias);
std::string name = ParseIdentifierAsName();
Expect(Advance(), TokenType::Assign);
ShaderAst::ExpressionPtr expr = ParseExpression();
Expect(Advance(), TokenType::Semicolon);
return ShaderBuilder::DeclareAlias(std::move(name), std::move(expr));
}
ShaderAst::StatementPtr Parser::ParseBranchStatement()
{
std::unique_ptr<ShaderAst::BranchStatement> branch = std::make_unique<ShaderAst::BranchStatement>();
@ -756,6 +788,12 @@ namespace Nz::ShaderLang
const Token& nextToken = Peek();
switch (nextToken.type)
{
case TokenType::Alias:
if (!attributes.empty())
throw UnexpectedToken{};
return ParseAliasDeclaration();
case TokenType::Const:
if (!attributes.empty())
throw UnexpectedToken{};

View File

@ -517,6 +517,8 @@ namespace Nz
else
targetAst = module.rootNode.get();
const ShaderAst::Module& targetModule = (sanitizedModule) ? *sanitizedModule : module;
ShaderAst::StatementPtr optimizedAst;
if (states.optimize)
{
@ -542,6 +544,9 @@ namespace Nz
// Register all extended instruction sets
PreVisitor preVisitor(state.constantTypeCache, state.funcs);
for (const auto& importedModule : targetModule.importedModules)
importedModule.module->rootNode->Visit(preVisitor);
targetAst->Visit(preVisitor);
m_currentState->preVisitor = &preVisitor;
@ -554,6 +559,9 @@ namespace Nz
func.funcId = AllocateResultId();
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
for (const auto& importedModule : targetModule.importedModules)
importedModule.module->rootNode->Visit(visitor);
targetAst->Visit(visitor);
AppendHeader();