diff --git a/examples/RenderTest/main.cpp b/examples/RenderTest/main.cpp index 23a2c8cf9..5dcde7161 100644 --- a/examples/RenderTest/main.cpp +++ b/examples/RenderTest/main.cpp @@ -45,8 +45,8 @@ struct FragOut [[entry(frag)]] fn main(fragIn: VertOut) -> FragOut { - let lightDir: vec3 = vec3(0.0, -0.707, 0.707); - let lightFactor: f32 = dot(fragIn.normal, lightDir); + let lightDir = vec3(0.0, -0.707, 0.707); + let lightFactor = dot(fragIn.normal, lightDir); let fragOut: FragOut; fragOut.color = lightFactor * tex.Sample(fragIn.uv); diff --git a/include/Nazara/Shader/ShaderAstValidator.hpp b/include/Nazara/Shader/ShaderAstValidator.hpp index 9556eee0f..14a491682 100644 --- a/include/Nazara/Shader/ShaderAstValidator.hpp +++ b/include/Nazara/Shader/ShaderAstValidator.hpp @@ -49,6 +49,7 @@ namespace Nz::ShaderAst void Visit(DeclareExternalStatement& node) override; void Visit(DeclareFunctionStatement& node) override; void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; void Visit(ExpressionStatement& node) override; void Visit(MultiStatement& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index ed6b3f624..5f01128e3 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -70,6 +70,7 @@ namespace Nz::ShaderBuilder struct DeclareVariable { + inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const; inline std::unique_ptr operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const; }; diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index a034d35f9..f72c98dbe 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -128,6 +128,15 @@ namespace Nz::ShaderBuilder return declareStructNode; } + inline std::unique_ptr Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const + { + auto declareVariableNode = std::make_unique(); + declareVariableNode->varName = std::move(name); + declareVariableNode->initialExpression = std::move(initialValue); + + return declareVariableNode; + } + inline std::unique_ptr Nz::ShaderBuilder::Impl::DeclareVariable::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const { auto declareVariableNode = std::make_unique(); diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index ed7dcff50..c979ff4b0 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -594,6 +594,21 @@ namespace Nz::ShaderAst AstScopedVisitor::Visit(node); } + void AstValidator::Visit(DeclareVariableStatement& node) + { + if (IsNoType(node.varType)) + { + if (!node.initialExpression) + throw AstError{ "variable must either have a type or an initial value" }; + + node.initialExpression->Visit(*this); + + node.varType = GetExpressionType(*node.initialExpression); + } + + AstScopedVisitor::Visit(node); + } + void AstValidator::Visit(ExpressionStatement& node) { MandatoryExpr(node.expression); diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index b0636f7be..6df94d266 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -637,15 +637,19 @@ namespace Nz::ShaderLang std::string variableName = ParseIdentifierAsName(); RegisterVariable(variableName); + + ShaderAst::ExpressionType variableType = ShaderAst::NoType{}; + if (Peek().type == TokenType::Colon) + { + Expect(Advance(), TokenType::Colon); - Expect(Advance(), TokenType::Colon); - - ShaderAst::ExpressionType variableType = ParseType(); + variableType = ParseType(); + } ShaderAst::ExpressionPtr expression; - if (Peek().type == TokenType::Assign) + if (IsNoType(variableType) || Peek().type == TokenType::Assign) { - Consume(); + Expect(Advance(), TokenType::Assign); expression = ParseExpression(); }