diff --git a/bin/resources/lighting.nzsl b/bin/resources/lighting.nzsl index f7275ea99..c4119b9c5 100644 --- a/bin/resources/lighting.nzsl +++ b/bin/resources/lighting.nzsl @@ -60,7 +60,6 @@ fn main(input: VertOut) -> FragOut let position = positionTexture.Sample(input.uv).xyz; let distance = length(lightParameters.position - position); - let attenuation = 1.0 / (lightParameters.constant + lightParameters.linear * distance + lightParameters.quadratic * (distance * distance)); let posToLight = (lightParameters.position - position) / distance; let lambert = dot(normal, posToLight); @@ -68,6 +67,7 @@ fn main(input: VertOut) -> FragOut let curAngle = dot(lightParameters.direction, -posToLight); let innerMinusOuterAngle = lightParameters.innerAngle - lightParameters.outerAngle; + let attenuation = compute_attenuation(distance); attenuation = attenuation * max((curAngle - lightParameters.outerAngle) / innerMinusOuterAngle, 0.0); let output: FragOut; @@ -85,3 +85,8 @@ fn main(input: VertIn) -> VertOut return output; } + +fn compute_attenuation(distance: f32) -> f32 +{ + return 1.0 / (lightParameters.constant + lightParameters.linear * distance + lightParameters.quadratic * (distance * distance)); +} diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index d4f5fc2c9..809c03f25 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -40,6 +40,8 @@ namespace Nz::ShaderAst virtual ExpressionPtr Clone(AccessMemberIndexExpression& node); virtual ExpressionPtr Clone(AssignExpression& node); virtual ExpressionPtr Clone(BinaryExpression& node); + virtual ExpressionPtr Clone(CallFunctionExpression& node); + virtual ExpressionPtr Clone(CallMethodExpression& node); virtual ExpressionPtr Clone(CastExpression& node); virtual ExpressionPtr Clone(ConditionalExpression& node); virtual ExpressionPtr Clone(ConstantExpression& node); diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index 47aa20c68..f68fe4276 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -30,6 +30,8 @@ NAZARA_SHADERAST_EXPRESSION(AccessMemberIdentifierExpression) NAZARA_SHADERAST_EXPRESSION(AccessMemberIndexExpression) NAZARA_SHADERAST_EXPRESSION(AssignExpression) NAZARA_SHADERAST_EXPRESSION(BinaryExpression) +NAZARA_SHADERAST_EXPRESSION(CallFunctionExpression) +NAZARA_SHADERAST_EXPRESSION(CallMethodExpression) NAZARA_SHADERAST_EXPRESSION(CastExpression) NAZARA_SHADERAST_EXPRESSION(ConditionalExpression) NAZARA_SHADERAST_EXPRESSION(ConstantExpression) diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index c82c98ab6..13a78ce1c 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -24,6 +24,8 @@ namespace Nz::ShaderAst void Visit(AccessMemberIndexExpression& node) override; void Visit(AssignExpression& node) override; void Visit(BinaryExpression& node) override; + void Visit(CallFunctionExpression& node) override; + void Visit(CallMethodExpression& node) override; void Visit(CastExpression& node) override; void Visit(ConditionalExpression& node) override; void Visit(ConstantExpression& node) override; diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index cbae0043f..20af6f2b4 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -27,6 +27,8 @@ namespace Nz::ShaderAst void Serialize(AccessMemberIndexExpression& node); void Serialize(AssignExpression& node); void Serialize(BinaryExpression& node); + void Serialize(CallFunctionExpression& node); + void Serialize(CallMethodExpression& node); void Serialize(CastExpression& node); void Serialize(ConditionalExpression& node); void Serialize(ConstantExpression& node); diff --git a/include/Nazara/Shader/Ast/AstUtils.hpp b/include/Nazara/Shader/Ast/AstUtils.hpp index 501f6fb45..560e482b6 100644 --- a/include/Nazara/Shader/Ast/AstUtils.hpp +++ b/include/Nazara/Shader/Ast/AstUtils.hpp @@ -35,6 +35,8 @@ namespace Nz::ShaderAst void Visit(AccessMemberIndexExpression& node) override; void Visit(AssignExpression& node) override; void Visit(BinaryExpression& node) override; + void Visit(CallFunctionExpression& node) override; + void Visit(CallMethodExpression& node) override; void Visit(CastExpression& node) override; void Visit(ConditionalExpression& node) override; void Visit(ConstantExpression& node) override; diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 4bfc7b51a..f44c74693 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -102,6 +102,25 @@ namespace Nz::ShaderAst ExpressionPtr right; }; + struct NAZARA_SHADER_API CallFunctionExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::variant targetFunction; + std::vector parameters; + }; + + struct NAZARA_SHADER_API CallMethodExpression : public Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + ExpressionPtr object; + std::string methodName; + std::vector parameters; + }; + struct NAZARA_SHADER_API CastExpression : public Expression { NodeType GetType() const override; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 9c9f79202..86bf41b2c 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -20,7 +21,7 @@ namespace Nz::ShaderAst public: struct Options; - inline SanitizeVisitor(); + SanitizeVisitor() = default; SanitizeVisitor(const SanitizeVisitor&) = delete; SanitizeVisitor(SanitizeVisitor&&) = delete; ~SanitizeVisitor() = default; @@ -47,6 +48,7 @@ namespace Nz::ShaderAst ExpressionPtr Clone(AccessMemberIdentifierExpression& node) override; ExpressionPtr Clone(AssignExpression& node) override; ExpressionPtr Clone(BinaryExpression& node) override; + ExpressionPtr Clone(CallFunctionExpression& node) override; ExpressionPtr Clone(CastExpression& node) override; ExpressionPtr Clone(ConditionalExpression& node) override; ExpressionPtr Clone(ConstantExpression& node) override; @@ -76,10 +78,11 @@ namespace Nz::ShaderAst void PushScope(); void PopScope(); - inline std::size_t RegisterFunction(std::string name); - inline std::size_t RegisterOption(std::string name, ExpressionType type); - inline std::size_t RegisterStruct(std::string name, StructDescription description); - inline std::size_t RegisterVariable(std::string name, ExpressionType type); + std::size_t RegisterFunction(DeclareFunctionStatement* funcDecl); + std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); + std::size_t RegisterOption(std::string name, ExpressionType type); + std::size_t RegisterStruct(std::string name, StructDescription description); + std::size_t RegisterVariable(std::string name, ExpressionType type); std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const IdentifierType& identifierType); @@ -89,37 +92,33 @@ namespace Nz::ShaderAst void SanitizeIdentifier(std::string& identifier); - struct Alias - { - std::variant value; - }; - - struct Option - { - std::size_t optionIndex; - }; - - struct Struct - { - std::size_t structIndex; - }; - - struct Variable - { - std::size_t varIndex; - }; + void Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration); + void Validate(IntrinsicExpression& node); struct Identifier { + enum class Type + { + Alias, + Function, + Intrinsic, + Option, + Struct, + Variable + }; + std::string name; - std::variant value; + std::size_t index; + Type type; }; - std::size_t m_nextFuncIndex; + std::unordered_map> m_functionDeclarations; std::vector m_identifiersInScope; + std::vector m_functions; + std::vector m_intrinsics; std::vector m_options; std::vector m_structs; - std::vector m_variables; + std::vector m_variableTypes; std::vector m_scopeSizes; struct Context; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.inl b/include/Nazara/Shader/Ast/SanitizeVisitor.inl index 59332d55f..4bc69711a 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.inl +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.inl @@ -7,11 +7,6 @@ namespace Nz::ShaderAst { - inline SanitizeVisitor::SanitizeVisitor() : - m_nextFuncIndex(0) - { - } - inline StatementPtr SanitizeVisitor::Sanitize(const StatementPtr& statement, std::string* error) { return Sanitize(statement, {}, error); @@ -26,56 +21,6 @@ namespace Nz::ShaderAst return &*it; } - inline std::size_t SanitizeVisitor::RegisterFunction(std::string name) - { - return m_nextFuncIndex++; - } - - inline std::size_t SanitizeVisitor::RegisterOption(std::string name, ExpressionType type) - { - std::size_t optionIndex = m_options.size(); - m_options.emplace_back(std::move(type)); - - m_identifiersInScope.push_back({ - std::move(name), - Option { - optionIndex - } - }); - - return optionIndex; - } - - inline std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription description) - { - std::size_t structIndex = m_structs.size(); - m_structs.emplace_back(std::move(description)); - - m_identifiersInScope.push_back({ - std::move(name), - Struct { - structIndex - } - }); - - return structIndex; - } - - inline std::size_t SanitizeVisitor::RegisterVariable(std::string name, ExpressionType type) - { - std::size_t varIndex = m_variables.size(); - m_variables.emplace_back(std::move(type)); - - m_identifiersInScope.push_back({ - std::move(name), - Variable { - varIndex - } - }); - - return varIndex; - } - inline StatementPtr Sanitize(const StatementPtr& ast, std::string* error) { SanitizeVisitor sanitizer; diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 6aa4c5117..ecf1a0923 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -15,6 +15,7 @@ #include #include #include +#include namespace Nz { @@ -60,8 +61,9 @@ namespace Nz template void Append(const T& param); template void Append(const T1& firstParam, const T2& secondParam, Args&&... params); void AppendCommentSection(const std::string& section); + void AppendFunctionDeclaration(const ShaderAst::DeclareFunctionStatement& node, bool forward = false); void AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers); - void AppendHeader(); + void AppendHeader(const std::vector& forwardFunctionDeclarations); void AppendLine(const std::string& txt = {}); template void AppendLine(Args&&... params); void AppendStatementList(std::vector& statements); @@ -72,6 +74,7 @@ namespace Nz void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node); void HandleInOut(); + void RegisterFunction(std::size_t funcIndex, std::string funcName); void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc); void RegisterVariable(std::size_t varIndex, std::string varName); @@ -80,6 +83,7 @@ namespace Nz void Visit(ShaderAst::AccessMemberIndexExpression& node) override; void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override; + void Visit(ShaderAst::CallFunctionExpression& node) override; void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::ConditionalExpression& node) override; void Visit(ShaderAst::ConstantExpression& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index f80e595ff..394b035f1 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -38,6 +38,11 @@ namespace Nz::ShaderBuilder inline std::unique_ptr operator()(std::vector condStatements, ShaderAst::StatementPtr elseStatement = nullptr) const; }; + struct CallFunction + { + inline std::unique_ptr operator()(std::string functionName, std::vector parameters) const; + }; + struct Cast { inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, std::array expressions) const; @@ -133,6 +138,7 @@ namespace Nz::ShaderBuilder constexpr Impl::Assign Assign; constexpr Impl::Binary Binary; constexpr Impl::Branch Branch; + constexpr Impl::CallFunction CallFunction; constexpr Impl::Cast Cast; constexpr Impl::ConditionalExpression ConditionalExpression; constexpr Impl::ConditionalStatement ConditionalStatement; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index dba57a872..33227d224 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -58,6 +58,15 @@ namespace Nz::ShaderBuilder return branchNode; } + inline std::unique_ptr Impl::CallFunction::operator()(std::string functionName, std::vector parameters) const + { + auto callFunctionExpression = std::make_unique(); + callFunctionExpression->targetFunction = std::move(functionName); + callFunctionExpression->parameters = std::move(parameters); + + return callFunctionExpression; + } + inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::array expressions) const { auto castNode = std::make_unique(); @@ -138,7 +147,7 @@ namespace Nz::ShaderBuilder return declareFunctionNode; } - inline std::unique_ptr Nz::ShaderBuilder::Impl::DeclareOption::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const + inline std::unique_ptr Impl::DeclareOption::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const { auto declareOptionNode = std::make_unique(); declareOptionNode->optName = std::move(name); @@ -156,7 +165,7 @@ namespace Nz::ShaderBuilder return declareStructNode; } - inline std::unique_ptr Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const + inline std::unique_ptr Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const { auto declareVariableNode = std::make_unique(); declareVariableNode->varName = std::move(name); @@ -165,7 +174,7 @@ namespace Nz::ShaderBuilder return declareVariableNode; } - inline std::unique_ptr Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const + inline std::unique_ptr Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const { auto declareVariableNode = std::make_unique(); declareVariableNode->varName = std::move(name); diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index 4733b9eb0..525906b0c 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -12,6 +12,7 @@ #include #include #include +#include namespace Nz::ShaderLang { @@ -69,7 +70,7 @@ namespace Nz::ShaderLang // Flow control const Token& Advance(); void Consume(std::size_t count = 1); - ShaderAst::ExpressionType DecodeType(const std::string& identifier); + std::optional DecodeType(const std::string& identifier); void EnterScope(); const Token& Expect(const Token& token, TokenType type); const Token& ExpectNot(const Token& token, TokenType type); diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index 5f52ac640..2a5531236 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -45,6 +45,7 @@ namespace Nz void Visit(ShaderAst::AssignExpression& node) override; void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::BranchStatement& node) override; + void Visit(ShaderAst::CallFunctionExpression& node) override; void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::ConditionalExpression& node) override; void Visit(ShaderAst::ConditionalStatement& node) override; @@ -92,7 +93,6 @@ namespace Nz ShaderStageType stageType; std::optional inputStruct; std::optional outputStructTypeId; - std::size_t funcIndex; std::vector inputs; std::vector outputs; }; @@ -101,6 +101,11 @@ namespace Nz { std::optional entryPointData; + struct FuncCall + { + std::size_t firstVarIndex; + }; + struct Parameter { UInt32 pointerTypeId; @@ -113,7 +118,9 @@ namespace Nz UInt32 varId; }; + std::size_t funcIndex; std::string name; + std::vector funcCalls; std::vector parameters; std::vector variables; std::unordered_map varIndexToVarId; @@ -138,6 +145,7 @@ namespace Nz inline void RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass); std::size_t m_extVarIndex; + std::size_t m_funcCallIndex; std::size_t m_funcIndex; std::vector m_scopeSizes; std::vector& m_funcData; diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index 999d445e0..ad37f7e0b 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -202,6 +202,36 @@ namespace Nz::ShaderAst return clone; } + ExpressionPtr AstCloner::Clone(CallFunctionExpression& node) + { + auto clone = std::make_unique(); + clone->targetFunction = node.targetFunction; + + clone->parameters.reserve(node.parameters.size()); + for (auto& parameter : node.parameters) + clone->parameters.push_back(CloneExpression(parameter)); + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + + ExpressionPtr AstCloner::Clone(CallMethodExpression& node) + { + auto clone = std::make_unique(); + clone->methodName = node.methodName; + + clone->object = CloneExpression(node.object); + + clone->parameters.reserve(node.parameters.size()); + for (auto& parameter : node.parameters) + clone->parameters.push_back(CloneExpression(parameter)); + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + ExpressionPtr AstCloner::Clone(CastExpression& node) { auto clone = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index 0e646be65..6c2c68823 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -29,6 +29,20 @@ namespace Nz::ShaderAst node.right->Visit(*this); } + void AstRecursiveVisitor::Visit(CallFunctionExpression& node) + { + for (auto& param : node.parameters) + param->Visit(*this); + } + + void AstRecursiveVisitor::Visit(CallMethodExpression& node) + { + node.object->Visit(*this); + + for (auto& param : node.parameters) + param->Visit(*this); + } + void AstRecursiveVisitor::Visit(CastExpression& node) { for (auto& expr : node.expressions) diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 78b1647ea..637ff04c3 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -65,6 +65,45 @@ namespace Nz::ShaderAst Node(node.right); } + void AstSerializerBase::Serialize(CallFunctionExpression& node) + { + UInt32 typeIndex; + if (IsWriting()) + typeIndex = UInt32(node.targetFunction.index()); + + Value(typeIndex); + + // Waiting for template lambda in C++20 + auto SerializeValue = [&](auto dummyType) + { + using T = std::decay_t; + + auto& value = (IsWriting()) ? std::get(node.targetFunction) : node.targetFunction.emplace(); + Value(value); + }; + + static_assert(std::variant_size_v == 2); + switch (typeIndex) + { + case 0: SerializeValue(std::string()); break; + case 1: SerializeValue(std::size_t()); break; + } + + Container(node.parameters); + for (auto& param : node.parameters) + Node(param); + } + + void AstSerializerBase::Serialize(CallMethodExpression& node) + { + Node(node.object); + Value(node.methodName); + + Container(node.parameters); + for (auto& param : node.parameters) + Node(param); + } + void AstSerializerBase::Serialize(CastExpression& node) { Type(node.targetType); diff --git a/src/Nazara/Shader/Ast/AstUtils.cpp b/src/Nazara/Shader/Ast/AstUtils.cpp index ae21847dd..798bba828 100644 --- a/src/Nazara/Shader/Ast/AstUtils.cpp +++ b/src/Nazara/Shader/Ast/AstUtils.cpp @@ -33,6 +33,16 @@ namespace Nz::ShaderAst m_expressionCategory = ExpressionCategory::RValue; } + void ShaderAstValueCategory::Visit(CallFunctionExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::RValue; + } + + void ShaderAstValueCategory::Visit(CallMethodExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::RValue; + } + void ShaderAstValueCategory::Visit(CastExpression& /*node*/) { m_expressionCategory = ExpressionCategory::RValue; diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 44a7de4aa..d5e3422ce 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -47,6 +48,28 @@ namespace Nz::ShaderAst PushScope(); //< Global scope { + RegisterIntrinsic("cross", IntrinsicType::CrossProduct); + RegisterIntrinsic("dot", IntrinsicType::DotProduct); + RegisterIntrinsic("max", IntrinsicType::Max); + RegisterIntrinsic("min", IntrinsicType::Min); + RegisterIntrinsic("length", IntrinsicType::Length); + + // Collect function name and their types + if (nodePtr->GetType() == NodeType::MultiStatement) + { + std::size_t functionIndex = 0; + + const MultiStatement& multiStatement = static_cast(*nodePtr); + for (const auto& statementPtr : multiStatement.statements) + { + if (statementPtr->GetType() == NodeType::DeclareFunctionStatement) + { + const DeclareFunctionStatement& funcDeclaration = static_cast(*statementPtr); + m_functionDeclarations.emplace(funcDeclaration.name, std::make_pair(&funcDeclaration, functionIndex++)); + } + } + } + try { clone = AstCloner::Clone(nodePtr); @@ -355,6 +378,71 @@ namespace Nz::ShaderAst return clone; } + ExpressionPtr SanitizeVisitor::Clone(CallFunctionExpression& node) + { + constexpr std::size_t NoFunction = std::numeric_limits::max(); + + auto clone = std::make_unique(); + + clone->parameters.reserve(node.parameters.size()); + for (std::size_t i = 0; i < node.parameters.size(); ++i) + clone->parameters.push_back(CloneExpression(node.parameters[i])); + + const DeclareFunctionStatement* referenceFunctionDeclaration; + if (std::holds_alternative(node.targetFunction)) + { + const std::string& functionName = std::get(node.targetFunction); + + const Identifier* identifier = FindIdentifier(functionName); + if (identifier) + { + if (identifier->type == Identifier::Type::Intrinsic) + { + // Intrinsic function call + std::vector parameters; + parameters.reserve(node.parameters.size()); + + for (const auto& param : node.parameters) + parameters.push_back(CloneExpression(param)); + + auto intrinsic = ShaderBuilder::Intrinsic(m_intrinsics[identifier->index], std::move(parameters)); + Validate(*intrinsic); + + return intrinsic; + } + else + { + // Regular function call + if (identifier->type != Identifier::Type::Function) + throw AstError{ "function expected" }; + + clone->targetFunction = identifier->index; + referenceFunctionDeclaration = m_functions[identifier->index]; + } + } + else + { + // Identifier not found, maybe the function is declared later + auto it = m_functionDeclarations.find(functionName); + if (it == m_functionDeclarations.end()) + throw AstError{ "function " + functionName + " does not exist" }; + + clone->targetFunction = it->second.second; + + referenceFunctionDeclaration = it->second.first; + } + } + else + { + std::size_t funcIndex = std::get(node.targetFunction); + referenceFunctionDeclaration = m_functions[funcIndex]; + } + + Validate(*clone, referenceFunctionDeclaration); + + return clone; + } + ExpressionPtr SanitizeVisitor::Clone(CastExpression& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); @@ -426,15 +514,13 @@ namespace Nz::ShaderAst if (!identifier) throw AstError{ "unknown identifier " + node.identifier }; - if (!std::holds_alternative(identifier->value)) + if (identifier->type != Identifier::Type::Variable) throw AstError{ "expected variable identifier" }; - const Variable& variable = std::get(identifier->value); - // Replace IdentifierExpression by VariableExpression auto varExpr = std::make_unique(); - varExpr->cachedExpressionType = m_variables[variable.varIndex]; - varExpr->variableId = variable.varIndex; + varExpr->cachedExpressionType = m_variableTypes[identifier->index]; + varExpr->variableId = identifier->index; return varExpr; } @@ -442,110 +528,7 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node) { auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); - - // Parameter validation - switch (clone->intrinsic) - { - case IntrinsicType::CrossProduct: - case IntrinsicType::DotProduct: - case IntrinsicType::Max: - case IntrinsicType::Min: - { - if (clone->parameters.size() != 2) - throw AstError { "Expected two parameters" }; - - for (auto& param : clone->parameters) - MandatoryExpr(param); - - const ExpressionType& type = GetExpressionType(*clone->parameters.front()); - - for (std::size_t i = 1; i < clone->parameters.size(); ++i) - { - if (type != GetExpressionType(*clone->parameters[i])) - throw AstError{ "All type must match" }; - } - - break; - } - - case IntrinsicType::Length: - { - if (clone->parameters.size() != 1) - throw AstError{ "Expected only one parameters" }; - - for (auto& param : clone->parameters) - MandatoryExpr(param); - - const ExpressionType& type = GetExpressionType(*clone->parameters.front()); - if (!IsVectorType(type)) - throw AstError{ "Expected a vector" }; - - break; - } - - case IntrinsicType::SampleTexture: - { - if (clone->parameters.size() != 2) - throw AstError{ "Expected two parameters" }; - - for (auto& param : clone->parameters) - MandatoryExpr(param); - - if (!IsSamplerType(GetExpressionType(*clone->parameters[0]))) - throw AstError{ "First parameter must be a sampler" }; - - if (!IsVectorType(GetExpressionType(*clone->parameters[1]))) - throw AstError{ "Second parameter must be a vector" }; - - break; - } - } - - // Return type attribution - switch (clone->intrinsic) - { - case IntrinsicType::CrossProduct: - { - const ExpressionType& type = GetExpressionType(*clone->parameters.front()); - if (type != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } }) - throw AstError{ "CrossProduct only works with vec3 expressions" }; - - clone->cachedExpressionType = type; - break; - } - - case IntrinsicType::DotProduct: - case IntrinsicType::Length: - { - ExpressionType type = GetExpressionType(*clone->parameters.front()); - if (!IsVectorType(type)) - throw AstError{ "DotProduct expects vector types" }; - - clone->cachedExpressionType = std::get(type).type; - break; - } - - case IntrinsicType::Max: - case IntrinsicType::Min: - { - const ExpressionType& type = GetExpressionType(*clone->parameters.front()); - if (!IsPrimitiveType(type) && !IsVectorType(type)) - throw AstError{ "max and min only work with primitive and vector types" }; - - if ((IsPrimitiveType(type) && std::get(type) == PrimitiveType::Boolean) || - (IsVectorType(type) && std::get(type).type == PrimitiveType::Boolean)) - throw AstError{ "max and min do not work with booleans" }; - - clone->cachedExpressionType = type; - break; - } - - case IntrinsicType::SampleTexture: - { - clone->cachedExpressionType = VectorType{ 4, std::get(GetExpressionType(*clone->parameters.front())).sampledType }; - break; - } - } + Validate(*clone); return clone; } @@ -563,10 +546,10 @@ namespace Nz::ShaderAst if (!identifier) throw AstError{ "unknown option " + node.optionName }; - if (!std::holds_alternative