diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index a74871da4..5a6f7049d 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -83,6 +83,7 @@ namespace Nz void Visit(ShaderNodes::LocalVariable& var) override; void Visit(ShaderNodes::NoOp& node) override; void Visit(ShaderNodes::ParameterVariable& var) override; + void Visit(ShaderNodes::ReturnStatement& node) override; void Visit(ShaderNodes::OutputVariable& var) override; void Visit(ShaderNodes::Sample2D& node) override; void Visit(ShaderNodes::StatementBlock& node) override; diff --git a/include/Nazara/Shader/ShaderAst.hpp b/include/Nazara/Shader/ShaderAst.hpp index 54199a179..f34b0ac9f 100644 --- a/include/Nazara/Shader/ShaderAst.hpp +++ b/include/Nazara/Shader/ShaderAst.hpp @@ -35,7 +35,7 @@ namespace Nz ~ShaderAst() = default; void AddCondition(std::string name); - void AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector parameters = {}, ShaderNodes::BasicType returnType = ShaderNodes::BasicType::Void); + void AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector parameters = {}, ShaderExpressionType returnType = ShaderNodes::BasicType::Void); void AddInput(std::string name, ShaderExpressionType type, std::optional locationIndex = {}); void AddOutput(std::string name, ShaderExpressionType type, std::optional locationIndex = {}); void AddStruct(std::string name, std::vector members); @@ -85,7 +85,7 @@ namespace Nz { std::string name; std::vector parameters; - ShaderNodes::BasicType returnType; + ShaderExpressionType returnType; ShaderNodes::StatementPtr statement; }; diff --git a/include/Nazara/Shader/ShaderAstCloner.hpp b/include/Nazara/Shader/ShaderAstCloner.hpp index 94b68cb0a..b777e2ba6 100644 --- a/include/Nazara/Shader/ShaderAstCloner.hpp +++ b/include/Nazara/Shader/ShaderAstCloner.hpp @@ -48,6 +48,7 @@ namespace Nz void Visit(ShaderNodes::Identifier& node) override; void Visit(ShaderNodes::IntrinsicCall& node) override; void Visit(ShaderNodes::NoOp& node) override; + void Visit(ShaderNodes::ReturnStatement& node) override; void Visit(ShaderNodes::Sample2D& node) override; void Visit(ShaderNodes::StatementBlock& node) override; void Visit(ShaderNodes::SwizzleOp& node) override; diff --git a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp index ce8ba487c..1b7e1eebd 100644 --- a/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/ShaderAstRecursiveVisitor.hpp @@ -35,6 +35,7 @@ namespace Nz void Visit(ShaderNodes::Identifier& node) override; void Visit(ShaderNodes::IntrinsicCall& node) override; void Visit(ShaderNodes::NoOp& node) override; + void Visit(ShaderNodes::ReturnStatement& node) override; void Visit(ShaderNodes::Sample2D& node) override; void Visit(ShaderNodes::StatementBlock& node) override; void Visit(ShaderNodes::SwizzleOp& node) override; diff --git a/include/Nazara/Shader/ShaderAstSerializer.hpp b/include/Nazara/Shader/ShaderAstSerializer.hpp index ad8af504d..b3e674fae 100644 --- a/include/Nazara/Shader/ShaderAstSerializer.hpp +++ b/include/Nazara/Shader/ShaderAstSerializer.hpp @@ -41,6 +41,7 @@ namespace Nz void Serialize(ShaderNodes::IntrinsicCall& node); void Serialize(ShaderNodes::NamedVariable& var); void Serialize(ShaderNodes::NoOp& node); + void Serialize(ShaderNodes::ReturnStatement& node); void Serialize(ShaderNodes::Sample2D& node); void Serialize(ShaderNodes::StatementBlock& node); void Serialize(ShaderNodes::SwizzleOp& node); diff --git a/include/Nazara/Shader/ShaderAstValidator.hpp b/include/Nazara/Shader/ShaderAstValidator.hpp index 579893aa4..90ec82c71 100644 --- a/include/Nazara/Shader/ShaderAstValidator.hpp +++ b/include/Nazara/Shader/ShaderAstValidator.hpp @@ -48,6 +48,7 @@ namespace Nz void Visit(ShaderNodes::ExpressionStatement& node) override; void Visit(ShaderNodes::Identifier& node) override; void Visit(ShaderNodes::IntrinsicCall& node) override; + void Visit(ShaderNodes::ReturnStatement& node) override; void Visit(ShaderNodes::Sample2D& node) override; void Visit(ShaderNodes::StatementBlock& node) override; void Visit(ShaderNodes::SwizzleOp& node) override; diff --git a/include/Nazara/Shader/ShaderAstVisitor.hpp b/include/Nazara/Shader/ShaderAstVisitor.hpp index 79fb23e9e..183a58f92 100644 --- a/include/Nazara/Shader/ShaderAstVisitor.hpp +++ b/include/Nazara/Shader/ShaderAstVisitor.hpp @@ -36,6 +36,7 @@ namespace Nz virtual void Visit(ShaderNodes::Identifier& node) = 0; virtual void Visit(ShaderNodes::IntrinsicCall& node) = 0; virtual void Visit(ShaderNodes::NoOp& node) = 0; + virtual void Visit(ShaderNodes::ReturnStatement& node) = 0; virtual void Visit(ShaderNodes::Sample2D& node) = 0; virtual void Visit(ShaderNodes::StatementBlock& node) = 0; virtual void Visit(ShaderNodes::SwizzleOp& node) = 0; diff --git a/include/Nazara/Shader/ShaderAstVisitorExcept.hpp b/include/Nazara/Shader/ShaderAstVisitorExcept.hpp index 65e507062..5076284fe 100644 --- a/include/Nazara/Shader/ShaderAstVisitorExcept.hpp +++ b/include/Nazara/Shader/ShaderAstVisitorExcept.hpp @@ -31,6 +31,7 @@ namespace Nz void Visit(ShaderNodes::Identifier& node) override; void Visit(ShaderNodes::IntrinsicCall& node) override; void Visit(ShaderNodes::NoOp& node) override; + void Visit(ShaderNodes::ReturnStatement& node) override; void Visit(ShaderNodes::Sample2D& node) override; void Visit(ShaderNodes::StatementBlock& node) override; void Visit(ShaderNodes::SwizzleOp& node) override; diff --git a/include/Nazara/Shader/ShaderEnums.hpp b/include/Nazara/Shader/ShaderEnums.hpp index 64d8729f3..87f2f4cd7 100644 --- a/include/Nazara/Shader/ShaderEnums.hpp +++ b/include/Nazara/Shader/ShaderEnums.hpp @@ -91,6 +91,7 @@ namespace Nz::ShaderNodes Identifier, IntrinsicCall, NoOp, + ReturnStatement, Sample2D, SwizzleOp, StatementBlock, diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index be6da0a35..6b3d296a3 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -10,6 +10,7 @@ #include #include #include +#include namespace Nz::ShaderLang { @@ -19,6 +20,18 @@ namespace Nz::ShaderLang using exception::exception; }; + class ReservedKeyword : public std::exception + { + public: + using exception::exception; + }; + + class UnknownType : public std::exception + { + public: + using exception::exception; + }; + class UnexpectedToken : public std::exception { public: @@ -31,19 +44,39 @@ namespace Nz::ShaderLang inline Parser(); ~Parser() = default; - void Parse(const std::vector& tokens); + ShaderAst Parse(const std::vector& tokens); private: + // Flow control const Token& Advance(); void Expect(const Token& token, TokenType type); - void ExpectNext(TokenType type); - void ParseFunctionBody(); - void ParseFunctionDeclaration(); - void ParseFunctionParameter(); + const Token& ExpectNext(TokenType type); const Token& PeekNext(); + // Statements + ShaderNodes::StatementPtr ParseFunctionBody(); + void ParseFunctionDeclaration(); + ShaderAst::FunctionParameter ParseFunctionParameter(); + ShaderNodes::StatementPtr ParseReturnStatement(); + ShaderNodes::StatementPtr ParseStatement(); + ShaderNodes::StatementPtr ParseStatementList(); + + // Expressions + ShaderNodes::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs); + ShaderNodes::ExpressionPtr ParseExpression(); + ShaderNodes::ExpressionPtr ParseIdentifier(); + ShaderNodes::ExpressionPtr ParseIntegerExpression(); + ShaderNodes::ExpressionPtr ParseParenthesisExpression(); + ShaderNodes::ExpressionPtr ParsePrimaryExpression(); + + std::string ParseIdentifierAsName(); + ShaderExpressionType ParseIdentifierAsType(); + + static int GetTokenPrecedence(TokenType token); + struct Context { + ShaderAst result; std::size_t tokenCount; std::size_t tokenIndex = 0; const Token* tokens; diff --git a/include/Nazara/Shader/ShaderLangTokenList.hpp b/include/Nazara/Shader/ShaderLangTokenList.hpp index b968c590e..caaa2964d 100644 --- a/include/Nazara/Shader/ShaderLangTokenList.hpp +++ b/include/Nazara/Shader/ShaderLangTokenList.hpp @@ -10,7 +10,6 @@ #define NAZARA_SHADERLANG_TOKEN_LAST(X) NAZARA_SHADERLANG_TOKEN(X) #endif -NAZARA_SHADERLANG_TOKEN(Add) NAZARA_SHADERLANG_TOKEN(BoolFalse) NAZARA_SHADERLANG_TOKEN(BoolTrue) NAZARA_SHADERLANG_TOKEN(ClosingParenthesis) @@ -26,11 +25,12 @@ NAZARA_SHADERLANG_TOKEN(FunctionReturn) NAZARA_SHADERLANG_TOKEN(IntegerValue) NAZARA_SHADERLANG_TOKEN(Identifier) NAZARA_SHADERLANG_TOKEN(Multiply) +NAZARA_SHADERLANG_TOKEN(Minus) +NAZARA_SHADERLANG_TOKEN(Plus) NAZARA_SHADERLANG_TOKEN(OpenCurlyBracket) NAZARA_SHADERLANG_TOKEN(OpenParenthesis) NAZARA_SHADERLANG_TOKEN(Semicolon) NAZARA_SHADERLANG_TOKEN(Return) -NAZARA_SHADERLANG_TOKEN(Subtract) #undef NAZARA_SHADERLANG_TOKEN #undef NAZARA_SHADERLANG_TOKEN_LAST diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index e2e082b97..0d1d57183 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -168,6 +168,17 @@ namespace Nz static inline std::shared_ptr Build(); }; + struct NAZARA_SHADER_API ReturnStatement : public Statement + { + inline ReturnStatement(); + + void Visit(ShaderAstVisitor& visitor) override; + + ExpressionPtr returnExpr; + + static inline std::shared_ptr Build(ExpressionPtr expr = nullptr); + }; + ////////////////////////////////////////////////////////////////////////// struct NAZARA_SHADER_API AssignOp : public Expression diff --git a/include/Nazara/Shader/ShaderNodes.inl b/include/Nazara/Shader/ShaderNodes.inl index d050a4a31..c1e2b22c3 100644 --- a/include/Nazara/Shader/ShaderNodes.inl +++ b/include/Nazara/Shader/ShaderNodes.inl @@ -194,7 +194,7 @@ namespace Nz::ShaderNodes } - inline ShaderNodes::NoOp::NoOp() : + inline NoOp::NoOp() : Statement(NodeType::NoOp) { } @@ -205,6 +205,20 @@ namespace Nz::ShaderNodes } + inline ReturnStatement::ReturnStatement() : + Statement(NodeType::ReturnStatement) + { + } + + inline std::shared_ptr ShaderNodes::ReturnStatement::Build(ExpressionPtr expr) + { + auto node = std::make_shared(); + node->returnExpr = std::move(expr); + + return node; + } + + inline AssignOp::AssignOp() : Expression(NodeType::AssignOp) { diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index 0cac761ca..cdf6fa2fe 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -43,6 +43,7 @@ namespace Nz void Visit(ShaderNodes::Identifier& node) override; void Visit(ShaderNodes::IntrinsicCall& node) override; void Visit(ShaderNodes::NoOp& node) override; + void Visit(ShaderNodes::ReturnStatement& node) override; void Visit(ShaderNodes::Sample2D& node) override; void Visit(ShaderNodes::StatementBlock& node) override; void Visit(ShaderNodes::SwizzleOp& node) override; diff --git a/include/Nazara/Shader/SpirvBlock.hpp b/include/Nazara/Shader/SpirvBlock.hpp index 5fbbe0843..6c6a9310d 100644 --- a/include/Nazara/Shader/SpirvBlock.hpp +++ b/include/Nazara/Shader/SpirvBlock.hpp @@ -8,14 +8,14 @@ #define NAZARA_SPIRVBLOCK_HPP #include -#include +#include #include #include #include namespace Nz { - class NAZARA_SHADER_API SpirvBlock : public SpirvSection + class NAZARA_SHADER_API SpirvBlock : public SpirvSectionBase { public: inline SpirvBlock(SpirvWriter& writer); @@ -23,13 +23,24 @@ namespace Nz SpirvBlock(SpirvBlock&&) = default; ~SpirvBlock() = default; + inline std::size_t Append(SpirvOp opcode, const OpSize& wordCount); + template std::size_t Append(SpirvOp opcode, Args&&... args); + template std::size_t AppendVariadic(SpirvOp opcode, F&& callback); + inline UInt32 GetLabelId() const; + inline bool IsTerminated() const; + SpirvBlock& operator=(const SpirvBlock&) = delete; SpirvBlock& operator=(SpirvBlock&&) = default; + static inline bool IsTerminationInstruction(SpirvOp op); + private: + inline void HandleSpirvOp(SpirvOp op); + UInt32 m_labelId; + bool m_isTerminated; }; } diff --git a/include/Nazara/Shader/SpirvBlock.inl b/include/Nazara/Shader/SpirvBlock.inl index 3ce1f3050..2dc50e76b 100644 --- a/include/Nazara/Shader/SpirvBlock.inl +++ b/include/Nazara/Shader/SpirvBlock.inl @@ -7,16 +7,70 @@ namespace Nz { - inline SpirvBlock::SpirvBlock(SpirvWriter& writer) + inline SpirvBlock::SpirvBlock(SpirvWriter& writer) : + m_isTerminated(false) { m_labelId = writer.AllocateResultId(); Append(SpirvOp::OpLabel, m_labelId); } + inline std::size_t SpirvBlock::Append(SpirvOp opcode, const OpSize& wordCount) + { + HandleSpirvOp(opcode); + + return SpirvSectionBase::Append(opcode, wordCount); + } + + template + std::size_t SpirvBlock::Append(SpirvOp opcode, Args&&... args) + { + HandleSpirvOp(opcode); + + return SpirvSectionBase::Append(opcode, std::forward(args)...); + } + + template + std::size_t SpirvBlock::AppendVariadic(SpirvOp opcode, F&& callback) + { + HandleSpirvOp(opcode); + + return SpirvSectionBase::AppendVariadic(opcode, std::forward(callback)); + } + inline UInt32 SpirvBlock::GetLabelId() const { return m_labelId; } + + inline bool SpirvBlock::IsTerminated() const + { + return m_isTerminated; + } + + inline bool SpirvBlock::IsTerminationInstruction(SpirvOp op) + { + switch (op) + { + case SpirvOp::OpBranch: + case SpirvOp::OpBranchConditional: + case SpirvOp::OpKill: + case SpirvOp::OpReturn: + case SpirvOp::OpReturnValue: + case SpirvOp::OpSwitch: + case SpirvOp::OpUnreachable: + return true; + + default: + return false; + } + } + + inline void SpirvBlock::HandleSpirvOp(SpirvOp op) + { + assert(!m_isTerminated); + if (IsTerminationInstruction(op)) + m_isTerminated = true; + } } #include diff --git a/include/Nazara/Shader/SpirvExpressionLoad.hpp b/include/Nazara/Shader/SpirvExpressionLoad.hpp index bb44bba71..f59369a52 100644 --- a/include/Nazara/Shader/SpirvExpressionLoad.hpp +++ b/include/Nazara/Shader/SpirvExpressionLoad.hpp @@ -36,6 +36,7 @@ namespace Nz using ShaderVarVisitor::Visit; void Visit(ShaderNodes::InputVariable& var) override; void Visit(ShaderNodes::LocalVariable& var) override; + void Visit(ShaderNodes::ParameterVariable& var) override; void Visit(ShaderNodes::UniformVariable& var) override; SpirvExpressionLoad& operator=(const SpirvExpressionLoad&) = delete; diff --git a/include/Nazara/Shader/SpirvPrinter.hpp b/include/Nazara/Shader/SpirvPrinter.hpp index fbfe0eeb2..c7c7815f7 100644 --- a/include/Nazara/Shader/SpirvPrinter.hpp +++ b/include/Nazara/Shader/SpirvPrinter.hpp @@ -10,6 +10,7 @@ #include #include #include +#include namespace Nz { @@ -23,7 +24,9 @@ namespace Nz SpirvPrinter(SpirvPrinter&&) = default; ~SpirvPrinter() = default; + inline std::string Print(const std::vector& codepoints); inline std::string Print(const UInt32* codepoints, std::size_t count); + inline std::string Print(const std::vector& codepoints, const Settings& settings); std::string Print(const UInt32* codepoints, std::size_t count, const Settings& settings); SpirvPrinter& operator=(const SpirvPrinter&) = default; diff --git a/include/Nazara/Shader/SpirvPrinter.inl b/include/Nazara/Shader/SpirvPrinter.inl index 1435293b0..866cc096f 100644 --- a/include/Nazara/Shader/SpirvPrinter.inl +++ b/include/Nazara/Shader/SpirvPrinter.inl @@ -12,11 +12,21 @@ namespace Nz { } + inline std::string SpirvPrinter::Print(const std::vector& codepoints) + { + return Print(codepoints.data(), codepoints.size()); + } + inline std::string SpirvPrinter::Print(const UInt32* codepoints, std::size_t count) { Settings settings; return Print(codepoints, count, settings); } + + inline std::string SpirvPrinter::Print(const std::vector& codepoints, const Settings& settings) + { + return Print(codepoints.data(), codepoints.size(), settings); + } } #include diff --git a/include/Nazara/Shader/SpirvSection.hpp b/include/Nazara/Shader/SpirvSection.hpp index b759cbe2e..498b4ec68 100644 --- a/include/Nazara/Shader/SpirvSection.hpp +++ b/include/Nazara/Shader/SpirvSection.hpp @@ -8,64 +8,25 @@ #define NAZARA_SPIRVSECTION_HPP #include -#include -#include -#include -#include +#include namespace Nz { - class NAZARA_SHADER_API SpirvSection + class NAZARA_SHADER_API SpirvSection : public SpirvSectionBase { public: - struct OpSize; - struct Raw; - SpirvSection() = default; SpirvSection(const SpirvSection&) = default; SpirvSection(SpirvSection&&) = default; ~SpirvSection() = default; - inline std::size_t Append(const char* str); - inline std::size_t Append(const std::string_view& str); - inline std::size_t Append(const std::string& str); - inline std::size_t Append(UInt32 value); - inline std::size_t Append(SpirvOp opcode, const OpSize& wordCount); - std::size_t Append(const Raw& raw); - inline std::size_t Append(std::initializer_list codepoints); - template std::size_t Append(SpirvOp opcode, const Args&... args); - template std::size_t AppendVariadic(SpirvOp opcode, F&& callback); - inline std::size_t Append(const SpirvSection& section); - template || std::is_enum_v>> std::size_t Append(T value); - - inline unsigned int CountWord(const char* str); - inline unsigned int CountWord(const std::string_view& str); - inline unsigned int CountWord(const std::string& str); - inline unsigned int CountWord(const Raw& raw); - template || std::is_enum_v>> unsigned int CountWord(const T& value); - template unsigned int CountWord(const T1& value, const T2& value2, const Args&... rest); - - inline const std::vector& GetBytecode() const; - inline std::size_t GetOutputOffset() const; + using SpirvSectionBase::Append; + using SpirvSectionBase::AppendRaw; + using SpirvSectionBase::AppendSection; + using SpirvSectionBase::AppendVariadic; SpirvSection& operator=(const SpirvSection&) = delete; SpirvSection& operator=(SpirvSection&&) = default; - - struct OpSize - { - unsigned int wc; - }; - - struct Raw - { - const void* ptr; - std::size_t size; - }; - - static inline UInt32 BuildOpcode(SpirvOp opcode, unsigned int wordCount); - - private: - std::vector m_bytecode; }; } diff --git a/include/Nazara/Shader/SpirvSection.inl b/include/Nazara/Shader/SpirvSection.inl index e6a400d06..941659162 100644 --- a/include/Nazara/Shader/SpirvSection.inl +++ b/include/Nazara/Shader/SpirvSection.inl @@ -7,151 +7,6 @@ namespace Nz { - inline std::size_t SpirvSection::Append(const char* str) - { - return Append(std::string_view(str)); - } - - inline std::size_t SpirvSection::Append(const std::string_view& str) - { - std::size_t offset = GetOutputOffset(); - - std::size_t size4 = CountWord(str); - for (std::size_t i = 0; i < size4; ++i) - { - UInt32 codepoint = 0; - for (std::size_t j = 0; j < 4; ++j) - { - std::size_t pos = i * 4 + j; - if (pos < str.size()) - codepoint |= UInt32(str[pos]) << (j * 8); - } - - Append(codepoint); - } - - return offset; - } - - inline std::size_t SpirvSection::Append(const std::string& str) - { - return Append(std::string_view(str)); - } - - inline std::size_t SpirvSection::Append(UInt32 value) - { - std::size_t offset = GetOutputOffset(); - m_bytecode.push_back(value); - - return offset; - } - - inline std::size_t SpirvSection::Append(SpirvOp opcode, const OpSize& wordCount) - { - return Append(BuildOpcode(opcode, wordCount.wc)); - } - - inline std::size_t SpirvSection::Append(std::initializer_list codepoints) - { - std::size_t offset = GetOutputOffset(); - - for (UInt32 cp : codepoints) - Append(cp); - - return offset; - } - - inline std::size_t SpirvSection::Append(const SpirvSection& section) - { - const std::vector& bytecode = section.GetBytecode(); - - std::size_t offset = GetOutputOffset(); - m_bytecode.resize(offset + bytecode.size()); - std::copy(bytecode.begin(), bytecode.end(), m_bytecode.begin() + offset); - - return offset; - } - - template - std::size_t SpirvSection::Append(SpirvOp opcode, const Args&... args) - { - unsigned int wordCount = 1 + (CountWord(args) + ... + 0); - std::size_t offset = Append(opcode, OpSize{ wordCount }); - if constexpr (sizeof...(args) > 0) - (Append(args), ...); - - return offset; - } - - template std::size_t SpirvSection::AppendVariadic(SpirvOp opcode, F&& callback) - { - std::size_t offset = Append(0); //< Will be filled later - - unsigned int wordCount = 1; - auto appendFunctor = [&](const auto& value) - { - wordCount += CountWord(value); - Append(value); - }; - callback(appendFunctor); - - m_bytecode[offset] = BuildOpcode(opcode, wordCount); - - return offset; - } - - template - std::size_t SpirvSection::Append(T value) - { - return Append(static_cast(value)); - } - - template - unsigned int SpirvSection::CountWord(const T& /*value*/) - { - return 1; - } - - template - unsigned int SpirvSection::CountWord(const T1& value, const T2& value2, const Args&... rest) - { - return CountWord(value) + CountWord(value2) + (CountWord(rest) + ...); - } - - inline unsigned int SpirvSection::CountWord(const char* str) - { - return CountWord(std::string_view(str)); - } - - inline unsigned int Nz::SpirvSection::CountWord(const std::string& str) - { - return CountWord(std::string_view(str)); - } - - inline unsigned int SpirvSection::CountWord(const Raw& raw) - { - return static_cast((raw.size + sizeof(UInt32) - 1) / sizeof(UInt32)); - } - - inline unsigned int SpirvSection::CountWord(const std::string_view& str) - { - return (static_cast(str.size() + 1) + sizeof(UInt32) - 1) / sizeof(UInt32); //< + 1 for null character - } - - inline const std::vector& SpirvSection::GetBytecode() const - { - return m_bytecode; - } - - inline std::size_t SpirvSection::GetOutputOffset() const - { - return m_bytecode.size(); - } - - inline UInt32 SpirvSection::BuildOpcode(SpirvOp opcode, unsigned int wordCount) - { - return UInt32(opcode) | UInt32(wordCount) << 16; - } } #include diff --git a/include/Nazara/Shader/SpirvSectionBase.hpp b/include/Nazara/Shader/SpirvSectionBase.hpp new file mode 100644 index 000000000..4783fd80b --- /dev/null +++ b/include/Nazara/Shader/SpirvSectionBase.hpp @@ -0,0 +1,75 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SPIRVSECTIONBASE_HPP +#define NAZARA_SPIRVSECTIONBASE_HPP + +#include +#include +#include +#include +#include + +namespace Nz +{ + class NAZARA_SHADER_API SpirvSectionBase + { + public: + struct OpSize; + struct Raw; + + SpirvSectionBase() = default; + SpirvSectionBase(const SpirvSectionBase&) = default; + SpirvSectionBase(SpirvSectionBase&&) = default; + ~SpirvSectionBase() = default; + + inline const std::vector& GetBytecode() const; + inline std::size_t GetOutputOffset() const; + + SpirvSectionBase& operator=(const SpirvSectionBase&) = delete; + SpirvSectionBase& operator=(SpirvSectionBase&&) = default; + + struct OpSize + { + unsigned int wc; + }; + + struct Raw + { + const void* ptr; + std::size_t size; + }; + + static inline UInt32 BuildOpcode(SpirvOp opcode, unsigned int wordCount); + + protected: + inline std::size_t Append(SpirvOp opcode, const OpSize& wordCount); + template std::size_t Append(SpirvOp opcode, const Args&... args); + template std::size_t AppendVariadic(SpirvOp opcode, F&& callback); + inline std::size_t AppendRaw(const char* str); + inline std::size_t AppendRaw(const std::string_view& str); + inline std::size_t AppendRaw(const std::string& str); + inline std::size_t AppendRaw(UInt32 value); + std::size_t AppendRaw(const Raw& raw); + inline std::size_t AppendRaw(std::initializer_list codepoints); + inline std::size_t AppendSection(const SpirvSectionBase& section); + template || std::is_enum_v>> std::size_t AppendRaw(T value); + + inline unsigned int CountWord(const char* str); + inline unsigned int CountWord(const std::string_view& str); + inline unsigned int CountWord(const std::string& str); + inline unsigned int CountWord(const Raw& raw); + template || std::is_enum_v>> unsigned int CountWord(const T& value); + template unsigned int CountWord(const T1& value, const T2& value2, const Args&... rest); + + private: + std::vector m_bytecode; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/SpirvSectionBase.inl b/include/Nazara/Shader/SpirvSectionBase.inl new file mode 100644 index 000000000..760bfea6b --- /dev/null +++ b/include/Nazara/Shader/SpirvSectionBase.inl @@ -0,0 +1,157 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + inline std::size_t SpirvSectionBase::Append(SpirvOp opcode, const OpSize& wordCount) + { + return AppendRaw(BuildOpcode(opcode, wordCount.wc)); + } + + template + std::size_t SpirvSectionBase::Append(SpirvOp opcode, const Args&... args) + { + unsigned int wordCount = 1 + (CountWord(args) + ... + 0); + std::size_t offset = Append(opcode, OpSize{ wordCount }); + if constexpr (sizeof...(args) > 0) + (AppendRaw(args), ...); + + return offset; + } + + template std::size_t SpirvSectionBase::AppendVariadic(SpirvOp opcode, F&& callback) + { + std::size_t offset = AppendRaw(0); //< Will be filled later + + unsigned int wordCount = 1; + auto appendFunctor = [&](const auto& value) + { + wordCount += CountWord(value); + AppendRaw(value); + }; + callback(appendFunctor); + + m_bytecode[offset] = BuildOpcode(opcode, wordCount); + + return offset; + } + + inline std::size_t SpirvSectionBase::AppendRaw(const char* str) + { + return AppendRaw(std::string_view(str)); + } + + inline std::size_t SpirvSectionBase::AppendRaw(const std::string_view& str) + { + std::size_t offset = GetOutputOffset(); + + std::size_t size4 = CountWord(str); + for (std::size_t i = 0; i < size4; ++i) + { + UInt32 codepoint = 0; + for (std::size_t j = 0; j < 4; ++j) + { + std::size_t pos = i * 4 + j; + if (pos < str.size()) + codepoint |= UInt32(str[pos]) << (j * 8); + } + + AppendRaw(codepoint); + } + + return offset; + } + + inline std::size_t SpirvSectionBase::AppendRaw(const std::string& str) + { + return AppendRaw(std::string_view(str)); + } + + inline std::size_t SpirvSectionBase::AppendRaw(UInt32 value) + { + std::size_t offset = GetOutputOffset(); + m_bytecode.push_back(value); + + return offset; + } + + inline std::size_t SpirvSectionBase::AppendRaw(std::initializer_list codepoints) + { + std::size_t offset = GetOutputOffset(); + + for (UInt32 cp : codepoints) + AppendRaw(cp); + + return offset; + } + + inline std::size_t SpirvSectionBase::AppendSection(const SpirvSectionBase& section) + { + const std::vector& bytecode = section.GetBytecode(); + + std::size_t offset = GetOutputOffset(); + m_bytecode.resize(offset + bytecode.size()); + std::copy(bytecode.begin(), bytecode.end(), m_bytecode.begin() + offset); + + return offset; + } + + template + std::size_t SpirvSectionBase::AppendRaw(T value) + { + return AppendRaw(static_cast(value)); + } + + template + unsigned int SpirvSectionBase::CountWord(const T& /*value*/) + { + return 1; + } + + template + unsigned int SpirvSectionBase::CountWord(const T1& value, const T2& value2, const Args&... rest) + { + return CountWord(value) + CountWord(value2) + (CountWord(rest) + ...); + } + + inline unsigned int SpirvSectionBase::CountWord(const char* str) + { + return CountWord(std::string_view(str)); + } + + inline unsigned int Nz::SpirvSectionBase::CountWord(const std::string& str) + { + return CountWord(std::string_view(str)); + } + + inline unsigned int SpirvSectionBase::CountWord(const Raw& raw) + { + return static_cast((raw.size + sizeof(UInt32) - 1) / sizeof(UInt32)); + } + + inline unsigned int SpirvSectionBase::CountWord(const std::string_view& str) + { + return (static_cast(str.size() + 1) + sizeof(UInt32) - 1) / sizeof(UInt32); //< + 1 for null character + } + + inline const std::vector& SpirvSectionBase::GetBytecode() const + { + return m_bytecode; + } + + inline std::size_t SpirvSectionBase::GetOutputOffset() const + { + return m_bytecode.size(); + } + + inline UInt32 SpirvSectionBase::BuildOpcode(SpirvOp opcode, unsigned int wordCount) + { + return UInt32(opcode) | UInt32(wordCount) << 16; + } +} + +#include diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index 37ebc6ea0..eeaee1e60 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -57,6 +57,8 @@ namespace Nz void AppendHeader(); + SpirvConstantCache::Function BuildFunctionType(ShaderExpressionType retType, const std::vector& parameters); + UInt32 GetConstantId(const ShaderConstantValue& value) const; UInt32 GetFunctionTypeId(ShaderExpressionType retType, const std::vector& parameters); const ExtVar& GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const; @@ -72,6 +74,8 @@ namespace Nz std::optional ReadInputVariable(const std::string& name, OnlyCache); UInt32 ReadLocalVariable(const std::string& name); std::optional ReadLocalVariable(const std::string& name, OnlyCache); + UInt32 ReadParameterVariable(const std::string& name); + std::optional ReadParameterVariable(const std::string& name, OnlyCache); UInt32 ReadUniformVariable(const std::string& name); std::optional ReadUniformVariable(const std::string& name, OnlyCache); UInt32 ReadVariable(ExtVar& var); @@ -89,8 +93,8 @@ namespace Nz struct Context { const ShaderAst* shader = nullptr; - const ShaderAst::Function* currentFunction = nullptr; const States* states = nullptr; + std::vector functionBlocks; }; struct ExtVar diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index f78896a3c..1920bf291 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -608,6 +608,18 @@ namespace Nz Append(var.name); } + void GlslWriter::Visit(ShaderNodes::ReturnStatement& node) + { + if (node.returnExpr) + { + Append("return "); + Visit(node.returnExpr); + Append(";"); + } + else + Append("return;"); + } + void GlslWriter::Visit(ShaderNodes::OutputVariable& var) { Append(var.name); diff --git a/src/Nazara/Shader/ShaderAst.cpp b/src/Nazara/Shader/ShaderAst.cpp index 2f43ee23b..ff84cfa55 100644 --- a/src/Nazara/Shader/ShaderAst.cpp +++ b/src/Nazara/Shader/ShaderAst.cpp @@ -13,7 +13,7 @@ namespace Nz conditionEntry.name = std::move(name); } - void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector parameters, ShaderNodes::BasicType returnType) + void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector parameters, ShaderExpressionType returnType) { auto& functionEntry = m_functions.emplace_back(); functionEntry.name = std::move(name); diff --git a/src/Nazara/Shader/ShaderAstCloner.cpp b/src/Nazara/Shader/ShaderAstCloner.cpp index ebbcfa560..1118029df 100644 --- a/src/Nazara/Shader/ShaderAstCloner.cpp +++ b/src/Nazara/Shader/ShaderAstCloner.cpp @@ -142,6 +142,11 @@ namespace Nz PushStatement(ShaderNodes::NoOp::Build()); } + void ShaderAstCloner::Visit(ShaderNodes::ReturnStatement& node) + { + PushStatement(ShaderNodes::ReturnStatement::Build(CloneExpression(node.returnExpr))); + } + void ShaderAstCloner::Visit(ShaderNodes::Sample2D& node) { PushExpression(ShaderNodes::Sample2D::Build(CloneExpression(node.sampler), CloneExpression(node.coordinates))); diff --git a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp index d3b404011..f4d858d91 100644 --- a/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/ShaderAstRecursiveVisitor.cpp @@ -95,6 +95,12 @@ namespace Nz /* Nothing to do */ } + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ReturnStatement& node) + { + if (node.returnExpr) + Visit(node.returnExpr); + } + void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Sample2D& node) { Visit(node.sampler); diff --git a/src/Nazara/Shader/ShaderAstSerializer.cpp b/src/Nazara/Shader/ShaderAstSerializer.cpp index b998d8637..13780339b 100644 --- a/src/Nazara/Shader/ShaderAstSerializer.cpp +++ b/src/Nazara/Shader/ShaderAstSerializer.cpp @@ -92,6 +92,11 @@ namespace Nz Serialize(node); } + void Visit(ShaderNodes::ReturnStatement& node) override + { + Serialize(node); + } + void Visit(ShaderNodes::Sample2D& node) override { Serialize(node); @@ -286,6 +291,11 @@ namespace Nz /* Nothing to do */ } + void ShaderAstSerializerBase::Serialize(ShaderNodes::ReturnStatement& node) + { + Node(node.returnExpr); + } + void ShaderAstSerializerBase::Serialize(ShaderNodes::Sample2D& node) { Node(node.sampler); @@ -391,7 +401,8 @@ namespace Nz m_stream << UInt32(shader.GetFunctionCount()); for (const auto& func : shader.GetFunctions()) { - m_stream << func.name << UInt32(func.returnType); + m_stream << func.name; + SerializeType(func.returnType); m_stream << UInt32(func.parameters.size()); for (const auto& param : func.parameters) @@ -634,11 +645,12 @@ namespace Nz for (UInt32 i = 0; i < funcCount; ++i) { std::string name; - ShaderNodes::BasicType retType; + ShaderExpressionType retType; std::vector parameters; Value(name); - Enum(retType); + Type(retType); + Container(parameters); for (auto& param : parameters) { @@ -653,7 +665,7 @@ namespace Nz ShaderNodes::StatementPtr statement = std::static_pointer_cast(node); - shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), retType); + shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), std::move(retType)); } return shader; @@ -693,6 +705,7 @@ namespace Nz HandleType(Identifier); HandleType(IntrinsicCall); HandleType(NoOp); + HandleType(ReturnStatement); HandleType(Sample2D); HandleType(SwizzleOp); HandleType(StatementBlock); diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index ec2d9a7d7..f081d7481 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -349,6 +349,22 @@ namespace Nz ShaderAstRecursiveVisitor::Visit(node); } + void ShaderAstValidator::Visit(ShaderNodes::ReturnStatement& node) + { + if (m_context->currentFunction->returnType != ShaderExpressionType(ShaderNodes::BasicType::Void)) + { + if (MandatoryExpr(node.returnExpr)->GetExpressionType() != m_context->currentFunction->returnType) + throw AstError{ "Return type doesn't match function return type" }; + } + else + { + if (node.returnExpr) + throw AstError{ "Unexpected expression for return (function doesn't return)" }; + } + + ShaderAstRecursiveVisitor::Visit(node); + } + void ShaderAstValidator::Visit(ShaderNodes::Sample2D& node) { if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D }) diff --git a/src/Nazara/Shader/ShaderAstVisitorExcept.cpp b/src/Nazara/Shader/ShaderAstVisitorExcept.cpp index 7843281d4..b8f8c7d16 100644 --- a/src/Nazara/Shader/ShaderAstVisitorExcept.cpp +++ b/src/Nazara/Shader/ShaderAstVisitorExcept.cpp @@ -78,6 +78,11 @@ namespace Nz throw std::runtime_error("unhandled NoOp node"); } + void ShaderAstVisitorExcept::Visit(ShaderNodes::ReturnStatement& node) + { + throw std::runtime_error("unhandled ReturnStatement node"); + } + void ShaderAstVisitorExcept::Visit(ShaderNodes::Sample2D& /*node*/) { throw std::runtime_error("unhandled Sample2D node"); diff --git a/src/Nazara/Shader/ShaderLangLexer.cpp b/src/Nazara/Shader/ShaderLangLexer.cpp index a93cefa85..8782a54fb 100644 --- a/src/Nazara/Shader/ShaderLangLexer.cpp +++ b/src/Nazara/Shader/ShaderLangLexer.cpp @@ -103,7 +103,7 @@ namespace Nz::ShaderLang break; } - tokenType = TokenType::Subtract; + tokenType = TokenType::Minus; break; } @@ -193,7 +193,7 @@ namespace Nz::ShaderLang char* end; double value = std::strtod(valueStr.c_str(), &end); - if (end != &str[currentPos]) + if (end != &str[currentPos + 1]) throw BadNumber{}; token.data = value; @@ -204,7 +204,7 @@ namespace Nz::ShaderLang long long value; std::from_chars_result r = std::from_chars(&str[start], &str[currentPos + 1], value); - if (r.ptr != &str[currentPos]) + if (r.ptr != &str[currentPos + 1]) { if (r.ec == std::errc::result_out_of_range) throw NumberOutOfRange{}; @@ -218,7 +218,7 @@ namespace Nz::ShaderLang break; } - case '+': tokenType = TokenType::Add; break; + case '+': tokenType = TokenType::Plus; break; case '*': tokenType = TokenType::Multiply; break; case ':': tokenType = TokenType::Colon; break; case ';': tokenType = TokenType::Semicolon; break; diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 482afd007..cea55ee2e 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -8,7 +8,33 @@ namespace Nz::ShaderLang { - void Parser::Parse(const std::vector& tokens) + namespace + { + std::unordered_map identifierToBasicType = { + { "bool", ShaderNodes::BasicType::Boolean }, + + { "i32", ShaderNodes::BasicType::Int1 }, + { "vec2i32", ShaderNodes::BasicType::Int2 }, + { "vec3i32", ShaderNodes::BasicType::Int3 }, + { "vec4i32", ShaderNodes::BasicType::Int4 }, + + { "f32", ShaderNodes::BasicType::Float1 }, + { "vec2f32", ShaderNodes::BasicType::Float2 }, + { "vec3f32", ShaderNodes::BasicType::Float3 }, + { "vec4f32", ShaderNodes::BasicType::Float4 }, + + { "mat4x4f32", ShaderNodes::BasicType::Mat4x4 }, + { "sampler2D", ShaderNodes::BasicType::Sampler2D }, + { "void", ShaderNodes::BasicType::Void }, + + { "u32", ShaderNodes::BasicType::UInt1 }, + { "vec2u32", ShaderNodes::BasicType::UInt3 }, + { "vec3u32", ShaderNodes::BasicType::UInt3 }, + { "vec4u32", ShaderNodes::BasicType::UInt4 }, + }; + } + + ShaderAst Parser::Parse(const std::vector& tokens) { Context context; context.tokenCount = tokens.size(); @@ -16,18 +42,28 @@ namespace Nz::ShaderLang m_context = &context; - for (const Token& token : tokens) + m_context->tokenIndex = -1; + + bool reachedEndOfStream = false; + while (!reachedEndOfStream) { - switch (token.type) + const Token& nextToken = PeekNext(); + switch (nextToken.type) { case TokenType::FunctionDeclaration: ParseFunctionDeclaration(); break; + case TokenType::EndOfStream: + reachedEndOfStream = true; + break; + default: throw UnexpectedToken{}; } } + + return std::move(context.result); } const Token& Parser::Advance() @@ -42,24 +78,34 @@ namespace Nz::ShaderLang throw ExpectedToken{}; } - void Parser::ExpectNext(TokenType type) + const Token& Parser::ExpectNext(TokenType type) { - Expect(m_context->tokens[m_context->tokenIndex + 1], type); + const Token& token = Advance(); + Expect(token, type); + + return token; } - void Parser::ParseFunctionBody() + const Token& Parser::PeekNext() { + assert(m_context->tokenIndex + 1 < m_context->tokenCount); + return m_context->tokens[m_context->tokenIndex + 1]; + } + ShaderNodes::StatementPtr Parser::ParseFunctionBody() + { + return ParseStatementList(); } void Parser::ParseFunctionDeclaration() { - ExpectNext(TokenType::Identifier); + ExpectNext(TokenType::FunctionDeclaration); - std::string functionName = std::get(Advance().data); + std::string functionName = ParseIdentifierAsName(); ExpectNext(TokenType::OpenParenthesis); - Advance(); + + std::vector parameters; bool firstParameter = true; for (;;) @@ -74,45 +120,192 @@ namespace Nz::ShaderLang Advance(); } - ParseFunctionParameter(); + parameters.push_back(ParseFunctionParameter()); firstParameter = false; } ExpectNext(TokenType::ClosingParenthesis); - Advance(); + ShaderExpressionType returnType = ShaderNodes::BasicType::Void; if (PeekNext().type == TokenType::FunctionReturn) { - Advance(); + Advance(); //< Consume -> - std::string returnType = std::get(Advance().data); + returnType = ParseIdentifierAsType(); } - ExpectNext(TokenType::OpenCurlyBracket); - Advance(); - ParseFunctionBody(); + ShaderNodes::StatementPtr functionBody = ParseFunctionBody(); ExpectNext(TokenType::ClosingCurlyBracket); - Advance(); + + m_context->result.AddFunction(functionName, functionBody, std::move(parameters), returnType); } - void Parser::ParseFunctionParameter() + ShaderAst::FunctionParameter Parser::ParseFunctionParameter() { - ExpectNext(TokenType::Identifier); - std::string parameterName = std::get(Advance().data); + std::string parameterName = ParseIdentifierAsName(); ExpectNext(TokenType::Colon); - Advance(); - ExpectNext(TokenType::Identifier); - std::string parameterType = std::get(Advance().data); + ShaderExpressionType parameterType = ParseIdentifierAsType(); + + return { parameterName, parameterType }; } - const Token& Parser::PeekNext() + ShaderNodes::StatementPtr Parser::ParseReturnStatement() { - assert(m_context->tokenIndex + 1 < m_context->tokenCount); - return m_context->tokens[m_context->tokenIndex + 1]; + ExpectNext(TokenType::Return); + + ShaderNodes::ExpressionPtr expr; + if (PeekNext().type != TokenType::Semicolon) + expr = ParseExpression(); + + return ShaderNodes::ReturnStatement::Build(std::move(expr)); + } + + ShaderNodes::StatementPtr Parser::ParseStatement() + { + const Token& token = PeekNext(); + + ShaderNodes::StatementPtr statement; + switch (token.type) + { + case TokenType::Return: + statement = ParseReturnStatement(); + break; + + default: + break; + } + + ExpectNext(TokenType::Semicolon); + + return statement; + } + + ShaderNodes::StatementPtr Parser::ParseStatementList() + { + std::vector statements; + while (PeekNext().type != TokenType::ClosingCurlyBracket) + { + statements.push_back(ParseStatement()); + } + + return ShaderNodes::StatementBlock::Build(std::move(statements)); + } + + ShaderNodes::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs) + { + for (;;) + { + const Token& currentOp = PeekNext(); + + int tokenPrecedence = GetTokenPrecedence(currentOp.type); + if (tokenPrecedence < exprPrecedence) + return lhs; + + Advance(); + ShaderNodes::ExpressionPtr rhs = ParsePrimaryExpression(); + + const Token& nextOp = PeekNext(); + + int nextTokenPrecedence = GetTokenPrecedence(nextOp.type); + if (tokenPrecedence < nextTokenPrecedence) + rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs)); + + ShaderNodes::BinaryType binaryType; + { + switch (currentOp.type) + { + case TokenType::Plus: binaryType = ShaderNodes::BinaryType::Add; break; + case TokenType::Minus: binaryType = ShaderNodes::BinaryType::Subtract; break; + case TokenType::Multiply: binaryType = ShaderNodes::BinaryType::Multiply; break; + case TokenType::Divide: binaryType = ShaderNodes::BinaryType::Divide; break; + default: throw UnexpectedToken{}; + } + } + + + lhs = ShaderNodes::BinaryOp::Build(binaryType, std::move(lhs), std::move(rhs)); + } + } + + ShaderNodes::ExpressionPtr Parser::ParseExpression() + { + return ParseBinOpRhs(0, ParsePrimaryExpression()); + } + + ShaderNodes::ExpressionPtr Parser::ParseIdentifier() + { + const Token& identifier = ExpectNext(TokenType::Identifier); + + return ShaderNodes::Identifier::Build(ShaderNodes::ParameterVariable::Build(std::get(identifier.data), ShaderNodes::BasicType::Float3)); + } + + ShaderNodes::ExpressionPtr Parser::ParseIntegerExpression() + { + const Token& integer = ExpectNext(TokenType::IntegerValue); + return ShaderNodes::Constant::Build(static_cast(std::get(integer.data))); + } + + ShaderNodes::ExpressionPtr Parser::ParseParenthesisExpression() + { + ExpectNext(TokenType::OpenParenthesis); + ShaderNodes::ExpressionPtr expression = ParseExpression(); + ExpectNext(TokenType::ClosingParenthesis); + + return expression; + } + + ShaderNodes::ExpressionPtr Parser::ParsePrimaryExpression() + { + const Token& token = PeekNext(); + switch (token.type) + { + case TokenType::BoolFalse: return ShaderNodes::Constant::Build(false); + case TokenType::BoolTrue: return ShaderNodes::Constant::Build(true); + case TokenType::Identifier: return ParseIdentifier(); + case TokenType::IntegerValue: return ParseIntegerExpression(); + case TokenType::OpenParenthesis: return ParseParenthesisExpression(); + default: throw UnexpectedToken{}; + } + } + + std::string Parser::ParseIdentifierAsName() + { + const Token& identifierToken = ExpectNext(TokenType::Identifier); + + std::string identifier = std::get(identifierToken.data); + + auto it = identifierToBasicType.find(identifier); + if (it != identifierToBasicType.end()) + throw ReservedKeyword{}; + + return identifier; + } + + ShaderExpressionType Parser::ParseIdentifierAsType() + { + const Token& identifier = ExpectNext(TokenType::Identifier); + + auto it = identifierToBasicType.find(std::get(identifier.data)); + if (it == identifierToBasicType.end()) + throw UnknownType{}; + + return it->second; + } + + int Parser::GetTokenPrecedence(TokenType token) + { + switch (token) + { + case TokenType::Plus: return 20; + case TokenType::Divide: return 40; + case TokenType::Multiply: return 40; + case TokenType::Minus: return 20; + default: return -1; + } } } diff --git a/src/Nazara/Shader/ShaderNodes.cpp b/src/Nazara/Shader/ShaderNodes.cpp index 0e617297a..97a21506b 100644 --- a/src/Nazara/Shader/ShaderNodes.cpp +++ b/src/Nazara/Shader/ShaderNodes.cpp @@ -69,6 +69,11 @@ namespace Nz::ShaderNodes visitor.Visit(*this); } + void ReturnStatement::Visit(ShaderAstVisitor& visitor) + { + visitor.Visit(*this); + } + ShaderExpressionType AssignOp::GetExpressionType() const { return left->GetExpressionType(); diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index fee8b7d2e..688648e70 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -611,6 +611,14 @@ namespace Nz // nothing to do } + void SpirvAstVisitor::Visit(ShaderNodes::ReturnStatement& node) + { + if (node.returnExpr) + m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr)); + else + m_currentBlock->Append(SpirvOp::OpReturn); + } + void SpirvAstVisitor::Visit(ShaderNodes::Sample2D& node) { UInt32 typeId = m_writer.GetTypeId(ShaderNodes::BasicType::Float4); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index e764cbf49..fd27b2dc6 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -18,6 +18,7 @@ namespace Nz template overloaded(Ts...)->overloaded; } + struct SpirvConstantCache::Eq { bool Compare(const ConstantBool& lhs, const ConstantBool& rhs) const @@ -353,6 +354,12 @@ namespace Nz }, v); } + void Register(const std::vector& lhs) + { + for (std::size_t i = 0; i < lhs.size(); ++i) + cache.Register(*lhs[i]); + } + template void Register(const std::vector& lhs) { diff --git a/src/Nazara/Shader/SpirvExpressionLoad.cpp b/src/Nazara/Shader/SpirvExpressionLoad.cpp index 93af139e0..caddbec43 100644 --- a/src/Nazara/Shader/SpirvExpressionLoad.cpp +++ b/src/Nazara/Shader/SpirvExpressionLoad.cpp @@ -109,6 +109,11 @@ namespace Nz m_value = Value{ m_writer.ReadLocalVariable(var.name) }; } + void SpirvExpressionLoad::Visit(ShaderNodes::ParameterVariable& var) + { + m_value = Value{ m_writer.ReadParameterVariable(var.name) }; + } + void SpirvExpressionLoad::Visit(ShaderNodes::UniformVariable& var) { auto uniformVar = m_writer.GetUniformVariable(var.name); diff --git a/src/Nazara/Shader/SpirvSection.cpp b/src/Nazara/Shader/SpirvSectionBase.cpp similarity index 85% rename from src/Nazara/Shader/SpirvSection.cpp rename to src/Nazara/Shader/SpirvSectionBase.cpp index c3d62ade3..9451e93db 100644 --- a/src/Nazara/Shader/SpirvSection.cpp +++ b/src/Nazara/Shader/SpirvSectionBase.cpp @@ -2,13 +2,13 @@ // This file is part of the "Nazara Engine - Shader generator" // For conditions of distribution and use, see copyright notice in Config.hpp -#include +#include #include #include namespace Nz { - std::size_t SpirvSection::Append(const Raw& raw) + std::size_t SpirvSectionBase::AppendRaw(const Raw& raw) { std::size_t offset = GetOutputOffset(); @@ -30,7 +30,7 @@ namespace Nz codepoint |= UInt32(ptr[pos]) << (j * 8); } - Append(codepoint); + AppendRaw(codepoint); } return offset; diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index bdc4c12f3..6316e42ee 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -191,16 +191,17 @@ namespace Nz { UInt32 typeId; UInt32 id; - std::vector paramsId; }; tsl::ordered_map inputIds; tsl::ordered_map outputIds; + tsl::ordered_map parameterIds; tsl::ordered_map uniformIds; std::unordered_map extensionInstructions; std::unordered_map builtinIds; std::unordered_map varToResult; std::vector funcs; + std::vector functionBlocks; std::vector resultIds; UInt32 nextVarIndex = 1; SpirvConstantCache constantTypeCache; //< init after nextVarIndex @@ -307,7 +308,7 @@ namespace Nz builtinData.typeId = GetTypeId(builtinType); builtinData.varId = varId; - state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpvDecorationBuiltIn, builtinDecoration); + state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpirvDecoration::BuiltIn, builtinDecoration); state.builtinIds.emplace(builtin->entry, builtinData); } @@ -329,7 +330,7 @@ namespace Nz state.inputIds.emplace(input.name, std::move(inputData)); if (input.locationIndex) - state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *input.locationIndex); + state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *input.locationIndex); } for (const auto& output : shader.GetOutputs()) @@ -349,7 +350,7 @@ namespace Nz state.outputIds.emplace(output.name, std::move(outputData)); if (output.locationIndex) - state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *output.locationIndex); + state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *output.locationIndex); } for (const auto& uniform : shader.GetUniforms()) @@ -370,8 +371,8 @@ namespace Nz if (uniform.bindingIndex) { - state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationBinding, *uniform.bindingIndex); - state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationDescriptorSet, 0); + state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex); + state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0); } } @@ -396,77 +397,86 @@ namespace Nz state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId); - std::vector blocks; - blocks.emplace_back(*this); + state.functionBlocks.clear(); + state.functionBlocks.emplace_back(*this); + + state.parameterIds.clear(); for (const auto& param : func.parameters) { UInt32 paramResultId = AllocateResultId(); - funcData.paramsId.push_back(paramResultId); + state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId); - blocks.back().Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId); + ExtVar parameterData; + parameterData.pointerTypeId = GetPointerTypeId(param.type, SpirvStorageClass::Function); + parameterData.typeId = GetTypeId(param.type); + parameterData.varId = paramResultId; + + state.parameterIds.emplace(param.name, std::move(parameterData)); } - SpirvAstVisitor visitor(*this, blocks); + SpirvAstVisitor visitor(*this, state.functionBlocks); visitor.Visit(functionStatements[funcIndex]); - if (func.returnType == ShaderNodes::BasicType::Void) - blocks.back().Append(SpirvOp::OpReturn); - else - throw std::runtime_error("returning values from functions is not yet supported"); //< TODO + if (!state.functionBlocks.back().IsTerminated()) + { + assert(func.returnType == ShaderExpressionType(ShaderNodes::BasicType::Void)); + state.functionBlocks.back().Append(SpirvOp::OpReturn); + } - blocks.back().Append(SpirvOp::OpFunctionEnd); + for (SpirvBlock& block : state.functionBlocks) + state.instructions.AppendSection(block); - for (SpirvBlock& block : blocks) - state.instructions.Append(block); + state.instructions.Append(SpirvOp::OpFunctionEnd); } - assert(entryPointIndex != std::numeric_limits::max()); - m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo); AppendHeader(); - SpvExecutionModel execModel; - const auto& entryFuncData = shader.GetFunction(entryPointIndex); - const auto& entryFunc = state.funcs[entryPointIndex]; - - assert(m_context.shader); - switch (m_context.shader->GetStage()) + if (entryPointIndex != std::numeric_limits::max()) { - case ShaderStageType::Fragment: - execModel = SpvExecutionModelFragment; - break; + SpvExecutionModel execModel; + const auto& entryFuncData = shader.GetFunction(entryPointIndex); + const auto& entryFunc = state.funcs[entryPointIndex]; - case ShaderStageType::Vertex: - execModel = SpvExecutionModelVertex; - break; + assert(m_context.shader); + switch (m_context.shader->GetStage()) + { + case ShaderStageType::Fragment: + execModel = SpvExecutionModelFragment; + break; - default: - throw std::runtime_error("not yet implemented"); + case ShaderStageType::Vertex: + execModel = SpvExecutionModelVertex; + break; + + default: + throw std::runtime_error("not yet implemented"); + } + + // OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos + + state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender) + { + appender(execModel); + appender(entryFunc.id); + appender(entryFuncData.name); + + for (const auto& [name, varData] : state.builtinIds) + appender(varData.varId); + + for (const auto& [name, varData] : state.inputIds) + appender(varData.varId); + + for (const auto& [name, varData] : state.outputIds) + appender(varData.varId); + }); + + if (m_context.shader->GetStage() == ShaderStageType::Fragment) + state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft); } - // OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos - - state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender) - { - appender(execModel); - appender(entryFunc.id); - appender(entryFuncData.name); - - for (const auto& [name, varData] : state.builtinIds) - appender(varData.varId); - - for (const auto& [name, varData] : state.inputIds) - appender(varData.varId); - - for (const auto& [name, varData] : state.outputIds) - appender(varData.varId); - }); - - if (m_context.shader->GetStage() == ShaderStageType::Fragment) - state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft); - std::vector ret; MergeSections(ret, state.header); MergeSections(ret, state.debugInfo); @@ -489,14 +499,14 @@ namespace Nz void SpirvWriter::AppendHeader() { - m_currentState->header.Append(SpvMagicNumber); //< Spir-V magic number + m_currentState->header.AppendRaw(SpvMagicNumber); //< Spir-V magic number UInt32 version = (m_environment.spvMajorVersion << 16) | m_environment.spvMinorVersion << 8; - m_currentState->header.Append(version); //< Spir-V version number (1.0 for compatibility) - m_currentState->header.Append(0); //< Generator identifier (TODO: Register generator to Khronos) + m_currentState->header.AppendRaw(version); //< Spir-V version number (1.0 for compatibility) + m_currentState->header.AppendRaw(0); //< Generator identifier (TODO: Register generator to Khronos) - m_currentState->header.Append(m_currentState->nextVarIndex); //< Bound (ID count) - m_currentState->header.Append(0); //< Instruction schema (required to be 0 for now) + m_currentState->header.AppendRaw(m_currentState->nextVarIndex); //< Bound (ID count) + m_currentState->header.AppendRaw(0); //< Instruction schema (required to be 0 for now) m_currentState->header.Append(SpirvOp::OpCapability, SpvCapabilityShader); @@ -506,6 +516,20 @@ namespace Nz m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450); } + SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderExpressionType retType, const std::vector& parameters) + { + std::vector parameterTypes; + parameterTypes.reserve(parameters.size()); + + for (const auto& parameter : parameters) + parameterTypes.push_back(SpirvConstantCache::BuildPointerType(*m_context.shader, parameter.type, SpirvStorageClass::Function)); + + return SpirvConstantCache::Function{ + SpirvConstantCache::BuildType(*m_context.shader, retType), + std::move(parameterTypes) + }; + } + UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const { return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value)); @@ -513,18 +537,7 @@ namespace Nz UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector& parameters) { - std::vector parameterTypes; - parameterTypes.reserve(parameters.size()); - - for (const auto& parameter : parameters) - parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type)); - - return m_currentState->constantTypeCache.GetId({ - SpirvConstantCache::Function { - SpirvConstantCache::BuildType(*m_context.shader, retType), - std::move(parameterTypes) - } - }); + return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) }); } auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar& @@ -602,6 +615,22 @@ namespace Nz return it->second; } + UInt32 SpirvWriter::ReadParameterVariable(const std::string& name) + { + auto it = m_currentState->parameterIds.find(name); + assert(it != m_currentState->parameterIds.end()); + + return ReadVariable(it.value()); + } + + std::optional SpirvWriter::ReadParameterVariable(const std::string& name, OnlyCache) + { + auto it = m_currentState->parameterIds.find(name); + assert(it != m_currentState->parameterIds.end()); + + return ReadVariable(it.value(), OnlyCache{}); + } + UInt32 SpirvWriter::ReadUniformVariable(const std::string& name) { auto it = m_currentState->uniformIds.find(name); @@ -623,7 +652,7 @@ namespace Nz if (!var.valueId.has_value()) { UInt32 resultId = AllocateResultId(); - m_currentState->instructions.Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId); + m_currentState->functionBlocks.back().Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId); var.valueId = resultId; } @@ -646,18 +675,7 @@ namespace Nz UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector& parameters) { - std::vector parameterTypes; - parameterTypes.reserve(parameters.size()); - - for (const auto& parameter : parameters) - parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type)); - - return m_currentState->constantTypeCache.Register({ - SpirvConstantCache::Function { - SpirvConstantCache::BuildType(*m_context.shader, retType), - std::move(parameterTypes) - } - }); + return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) }); } UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass)