diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index 91ad031ad..571af6cc5 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -4,13 +4,14 @@ #pragma once -#ifndef NAZARA_SPIRVEXPRESSIONLOAD_HPP -#define NAZARA_SPIRVEXPRESSIONLOAD_HPP +#ifndef NAZARA_SPIRVASTVISITOR_HPP +#define NAZARA_SPIRVASTVISITOR_HPP #include #include #include #include +#include #include namespace Nz @@ -20,7 +21,7 @@ namespace Nz class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAstVisitorExcept { public: - inline SpirvAstVisitor(SpirvWriter& writer); + inline SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks); SpirvAstVisitor(const SpirvAstVisitor&) = delete; SpirvAstVisitor(SpirvAstVisitor&&) = delete; ~SpirvAstVisitor() = default; @@ -31,6 +32,7 @@ namespace Nz void Visit(ShaderNodes::AccessMember& node) override; void Visit(ShaderNodes::AssignOp& node) override; void Visit(ShaderNodes::BinaryOp& node) override; + void Visit(ShaderNodes::Branch& node) override; void Visit(ShaderNodes::Cast& node) override; void Visit(ShaderNodes::ConditionalExpression& node) override; void Visit(ShaderNodes::ConditionalStatement& node) override; @@ -51,6 +53,8 @@ namespace Nz void PushResultId(UInt32 value); UInt32 PopResultId(); + SpirvBlock* m_currentBlock; + std::vector& m_blocks; std::vector m_resultIds; SpirvWriter& m_writer; }; diff --git a/include/Nazara/Shader/SpirvAstVisitor.inl b/include/Nazara/Shader/SpirvAstVisitor.inl index 87dc93a54..048f5768e 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.inl +++ b/include/Nazara/Shader/SpirvAstVisitor.inl @@ -7,9 +7,11 @@ namespace Nz { - inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer) : + inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks) : + m_blocks(blocks), m_writer(writer) { + m_currentBlock = &m_blocks.back(); } } diff --git a/include/Nazara/Shader/SpirvBlock.hpp b/include/Nazara/Shader/SpirvBlock.hpp new file mode 100644 index 000000000..5fbbe0843 --- /dev/null +++ b/include/Nazara/Shader/SpirvBlock.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SPIRVBLOCK_HPP +#define NAZARA_SPIRVBLOCK_HPP + +#include +#include +#include +#include +#include + +namespace Nz +{ + class NAZARA_SHADER_API SpirvBlock : public SpirvSection + { + public: + inline SpirvBlock(SpirvWriter& writer); + SpirvBlock(const SpirvBlock&) = default; + SpirvBlock(SpirvBlock&&) = default; + ~SpirvBlock() = default; + + inline UInt32 GetLabelId() const; + + SpirvBlock& operator=(const SpirvBlock&) = delete; + SpirvBlock& operator=(SpirvBlock&&) = default; + + private: + UInt32 m_labelId; + }; +} + +#include + +#endif diff --git a/include/Nazara/Shader/SpirvBlock.inl b/include/Nazara/Shader/SpirvBlock.inl new file mode 100644 index 000000000..3ce1f3050 --- /dev/null +++ b/include/Nazara/Shader/SpirvBlock.inl @@ -0,0 +1,22 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + inline SpirvBlock::SpirvBlock(SpirvWriter& writer) + { + m_labelId = writer.AllocateResultId(); + Append(SpirvOp::OpLabel, m_labelId); + } + + inline UInt32 SpirvBlock::GetLabelId() const + { + return m_labelId; + } +} + +#include diff --git a/include/Nazara/Shader/SpirvExpressionLoad.hpp b/include/Nazara/Shader/SpirvExpressionLoad.hpp index b237c720a..bb44bba71 100644 --- a/include/Nazara/Shader/SpirvExpressionLoad.hpp +++ b/include/Nazara/Shader/SpirvExpressionLoad.hpp @@ -4,8 +4,8 @@ #pragma once -#ifndef NAZARA_SPIRVEXPRESSIONLOADACCESSMEMBER_HPP -#define NAZARA_SPIRVEXPRESSIONLOADACCESSMEMBER_HPP +#ifndef NAZARA_SPIRVEXPRESSIONLOAD_HPP +#define NAZARA_SPIRVEXPRESSIONLOAD_HPP #include #include @@ -16,12 +16,13 @@ namespace Nz { + class SpirvBlock; class SpirvWriter; class NAZARA_SHADER_API SpirvExpressionLoad : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept { public: - inline SpirvExpressionLoad(SpirvWriter& writer); + inline SpirvExpressionLoad(SpirvWriter& writer, SpirvBlock& block); SpirvExpressionLoad(const SpirvExpressionLoad&) = delete; SpirvExpressionLoad(SpirvExpressionLoad&&) = delete; ~SpirvExpressionLoad() = default; @@ -53,6 +54,7 @@ namespace Nz UInt32 resultId; }; + SpirvBlock& m_block; SpirvWriter& m_writer; std::variant m_value; }; diff --git a/include/Nazara/Shader/SpirvExpressionLoad.inl b/include/Nazara/Shader/SpirvExpressionLoad.inl index 966aae912..6d5aff9cb 100644 --- a/include/Nazara/Shader/SpirvExpressionLoad.inl +++ b/include/Nazara/Shader/SpirvExpressionLoad.inl @@ -7,7 +7,8 @@ namespace Nz { - inline SpirvExpressionLoad::SpirvExpressionLoad(SpirvWriter& writer) : + inline SpirvExpressionLoad::SpirvExpressionLoad(SpirvWriter& writer, SpirvBlock& block) : + m_block(block), m_writer(writer) { } diff --git a/include/Nazara/Shader/SpirvExpressionStore.hpp b/include/Nazara/Shader/SpirvExpressionStore.hpp index d7d37e39d..26c2b5f48 100644 --- a/include/Nazara/Shader/SpirvExpressionStore.hpp +++ b/include/Nazara/Shader/SpirvExpressionStore.hpp @@ -15,13 +15,13 @@ namespace Nz { - class SpirvSection; + class SpirvBlock; class SpirvWriter; class NAZARA_SHADER_API SpirvExpressionStore : public ShaderAstVisitorExcept, public ShaderVarVisitorExcept { public: - inline SpirvExpressionStore(SpirvWriter& writer); + inline SpirvExpressionStore(SpirvWriter& writer, SpirvBlock& block); SpirvExpressionStore(const SpirvExpressionStore&) = delete; SpirvExpressionStore(SpirvExpressionStore&&) = delete; ~SpirvExpressionStore() = default; @@ -53,6 +53,7 @@ namespace Nz UInt32 resultId; }; + SpirvBlock& m_block; SpirvWriter& m_writer; std::variant m_value; }; diff --git a/include/Nazara/Shader/SpirvExpressionStore.inl b/include/Nazara/Shader/SpirvExpressionStore.inl index 558a2aee8..771624788 100644 --- a/include/Nazara/Shader/SpirvExpressionStore.inl +++ b/include/Nazara/Shader/SpirvExpressionStore.inl @@ -7,7 +7,8 @@ namespace Nz { - inline SpirvExpressionStore::SpirvExpressionStore(SpirvWriter& writer) : + inline SpirvExpressionStore::SpirvExpressionStore(SpirvWriter& writer, SpirvBlock& block) : + m_block(block), m_writer(writer) { } diff --git a/include/Nazara/Shader/SpirvSection.hpp b/include/Nazara/Shader/SpirvSection.hpp index b116b212f..b759cbe2e 100644 --- a/include/Nazara/Shader/SpirvSection.hpp +++ b/include/Nazara/Shader/SpirvSection.hpp @@ -22,8 +22,8 @@ namespace Nz struct Raw; SpirvSection() = default; - SpirvSection(const SpirvSection& cache) = default; - SpirvSection(SpirvSection&& cache) = default; + SpirvSection(const SpirvSection&) = default; + SpirvSection(SpirvSection&&) = default; ~SpirvSection() = default; inline std::size_t Append(const char* str); @@ -35,20 +35,21 @@ namespace Nz inline std::size_t Append(std::initializer_list codepoints); template std::size_t Append(SpirvOp opcode, const Args&... args); template std::size_t AppendVariadic(SpirvOp opcode, F&& callback); - template std::size_t Append(T value); + inline std::size_t Append(const SpirvSection& section); + template || std::is_enum_v>> std::size_t Append(T value); inline unsigned int CountWord(const char* str); inline unsigned int CountWord(const std::string_view& str); inline unsigned int CountWord(const std::string& str); inline unsigned int CountWord(const Raw& raw); - template unsigned int CountWord(const T& value); + template || std::is_enum_v>> unsigned int CountWord(const T& value); template unsigned int CountWord(const T1& value, const T2& value2, const Args&... rest); inline const std::vector& GetBytecode() const; inline std::size_t GetOutputOffset() const; - SpirvSection& operator=(const SpirvSection& cache) = delete; - SpirvSection& operator=(SpirvSection&& cache) = default; + SpirvSection& operator=(const SpirvSection&) = delete; + SpirvSection& operator=(SpirvSection&&) = default; struct OpSize { diff --git a/include/Nazara/Shader/SpirvSection.inl b/include/Nazara/Shader/SpirvSection.inl index e86eb59ff..e6a400d06 100644 --- a/include/Nazara/Shader/SpirvSection.inl +++ b/include/Nazara/Shader/SpirvSection.inl @@ -61,6 +61,17 @@ namespace Nz return offset; } + inline std::size_t SpirvSection::Append(const SpirvSection& section) + { + const std::vector& bytecode = section.GetBytecode(); + + std::size_t offset = GetOutputOffset(); + m_bytecode.resize(offset + bytecode.size()); + std::copy(bytecode.begin(), bytecode.end(), m_bytecode.begin() + offset); + + return offset; + } + template std::size_t SpirvSection::Append(SpirvOp opcode, const Args&... args) { @@ -89,13 +100,13 @@ namespace Nz return offset; } - template + template std::size_t SpirvSection::Append(T value) { return Append(static_cast(value)); } - template + template unsigned int SpirvSection::CountWord(const T& /*value*/) { return 1; diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index 62227ea4e..37ebc6ea0 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -26,6 +26,7 @@ namespace Nz class NAZARA_SHADER_API SpirvWriter : public ShaderWriter { friend class SpirvAstVisitor; + friend class SpirvBlock; friend class SpirvExpressionLoad; friend class SpirvExpressionStore; friend class SpirvVisitor; @@ -62,7 +63,6 @@ namespace Nz const ExtVar& GetInputVariable(const std::string& name) const; const ExtVar& GetOutputVariable(const std::string& name) const; const ExtVar& GetUniformVariable(const std::string& name) const; - SpirvSection& GetInstructions(); UInt32 GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const; UInt32 GetTypeId(const ShaderExpressionType& type) const; diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index e012d6de3..f76c8807c 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -22,7 +22,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderNodes::AccessMember& node) { - SpirvExpressionLoad accessMemberVisitor(m_writer); + SpirvExpressionLoad accessMemberVisitor(m_writer, *m_currentBlock); PushResultId(accessMemberVisitor.Evaluate(node)); } @@ -30,7 +30,7 @@ namespace Nz { UInt32 resultId = EvaluateExpression(node.right); - SpirvExpressionStore storeVisitor(m_writer); + SpirvExpressionStore storeVisitor(m_writer, *m_currentBlock); storeVisitor.Store(node.left, resultId); PushResultId(resultId); @@ -438,10 +438,63 @@ namespace Nz if (swapOperands) std::swap(leftOperand, rightOperand); - m_writer.GetInstructions().Append(op, m_writer.GetTypeId(resultType), resultId, leftOperand, rightOperand); + m_currentBlock->Append(op, m_writer.GetTypeId(resultType), resultId, leftOperand, rightOperand); PushResultId(resultId); } + void SpirvAstVisitor::Visit(ShaderNodes::Branch& node) + { + assert(!node.condStatements.empty()); + auto& firstCond = node.condStatements.front(); + + UInt32 previousConditionId = EvaluateExpression(firstCond.condition); + SpirvBlock previousContentBlock(m_writer); + m_currentBlock = &previousContentBlock; + Visit(firstCond.statement); + + std::optional nextBlock; + for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex) + { + const auto& statement = node.condStatements[statementIndex]; + + SpirvBlock contentBlock(m_writer); + + m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId()); + + previousConditionId = EvaluateExpression(statement.condition); + m_blocks.emplace_back(std::move(previousContentBlock)); + previousContentBlock = std::move(contentBlock); + + m_currentBlock = &previousContentBlock; + Visit(statement.statement); + } + + SpirvBlock mergeBlock(m_writer); + + if (node.elseStatement) + { + SpirvBlock elseBlock(m_writer); + + m_currentBlock = &elseBlock; + Visit(node.elseStatement); + + elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice + + m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId()); + m_blocks.emplace_back(std::move(previousContentBlock)); + m_blocks.emplace_back(std::move(elseBlock)); + } + else + { + m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId()); + m_blocks.emplace_back(std::move(previousContentBlock)); + } + + m_blocks.emplace_back(std::move(mergeBlock)); + + m_currentBlock = &m_blocks.back(); + } + void SpirvAstVisitor::Visit(ShaderNodes::Cast& node) { const ShaderExpressionType& targetExprType = node.exprType; @@ -461,7 +514,7 @@ namespace Nz UInt32 resultId = m_writer.AllocateResultId(); - m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender) + m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender) { appender(m_writer.GetTypeId(targetType)); appender(resultId); @@ -508,7 +561,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderNodes::Discard& /*node*/) { - m_writer.GetInstructions().Append(SpirvOp::OpKill); + m_currentBlock->Append(SpirvOp::OpKill); } void SpirvAstVisitor::Visit(ShaderNodes::ExpressionStatement& node) @@ -519,7 +572,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderNodes::Identifier& node) { - SpirvExpressionLoad loadVisitor(m_writer); + SpirvExpressionLoad loadVisitor(m_writer, *m_currentBlock); PushResultId(loadVisitor.Evaluate(node)); } @@ -541,7 +594,7 @@ namespace Nz UInt32 resultId = m_writer.AllocateResultId(); - m_writer.GetInstructions().Append(SpirvOp::OpDot, typeId, resultId, vec1, vec2); + m_currentBlock->Append(SpirvOp::OpDot, typeId, resultId, vec1, vec2); PushResultId(resultId); break; } @@ -560,7 +613,7 @@ namespace Nz UInt32 coordinatesId = EvaluateExpression(node.coordinates); UInt32 resultId = m_writer.AllocateResultId(); - m_writer.GetInstructions().Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId); + m_currentBlock->Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId); PushResultId(resultId); } @@ -583,7 +636,7 @@ namespace Nz if (node.componentCount > 1) { // Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands - m_writer.GetInstructions().AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender) + m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender) { appender(m_writer.GetTypeId(targetType)); appender(resultId); @@ -599,7 +652,7 @@ namespace Nz // Extract a single component from the vector assert(node.componentCount == 1); - m_writer.GetInstructions().Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) ); + m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) ); } PushResultId(resultId); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index b8489ce76..e764cbf49 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -682,7 +682,7 @@ namespace Nz using ConstantType = std::decay_t; if constexpr (std::is_same_v) - constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, resultId); + constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, GetId({ Bool{} }), resultId); else if constexpr (std::is_same_v) { constants.AppendVariadic(SpirvOp::OpConstantComposite, [&](const auto& appender) diff --git a/src/Nazara/Shader/SpirvExpressionLoad.cpp b/src/Nazara/Shader/SpirvExpressionLoad.cpp index 026fefee8..b82d9b046 100644 --- a/src/Nazara/Shader/SpirvExpressionLoad.cpp +++ b/src/Nazara/Shader/SpirvExpressionLoad.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include @@ -26,7 +26,7 @@ namespace Nz { UInt32 resultId = m_writer.AllocateResultId(); - m_writer.GetInstructions().Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId); + m_block.Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId); return resultId; }, @@ -53,7 +53,7 @@ namespace Nz UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME UInt32 typeId = m_writer.GetTypeId(node.exprType); - m_writer.GetInstructions().AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) + m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) { appender(pointerType); appender(resultId); @@ -70,7 +70,7 @@ namespace Nz UInt32 resultId = m_writer.AllocateResultId(); UInt32 typeId = m_writer.GetTypeId(node.exprType); - m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender) + m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender) { appender(typeId); appender(resultId); diff --git a/src/Nazara/Shader/SpirvExpressionStore.cpp b/src/Nazara/Shader/SpirvExpressionStore.cpp index 9123bdcf4..a0c5511d1 100644 --- a/src/Nazara/Shader/SpirvExpressionStore.cpp +++ b/src/Nazara/Shader/SpirvExpressionStore.cpp @@ -3,7 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include -#include +#include #include #include @@ -23,7 +23,7 @@ namespace Nz { [&](const Pointer& pointer) { - m_writer.GetInstructions().Append(SpirvOp::OpStore, pointer.resultId, resultId); + m_block.Append(SpirvOp::OpStore, pointer.resultId, resultId); }, [&](const LocalVar& value) { @@ -47,7 +47,7 @@ namespace Nz UInt32 resultId = m_writer.AllocateResultId(); UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME - m_writer.GetInstructions().AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) + m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender) { appender(pointerType); appender(resultId); diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 0e1af9c0b..bdc4c12f3 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -395,23 +396,29 @@ namespace Nz state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId); - state.instructions.Append(SpirvOp::OpLabel, AllocateResultId()); + std::vector blocks; + blocks.emplace_back(*this); for (const auto& param : func.parameters) { UInt32 paramResultId = AllocateResultId(); funcData.paramsId.push_back(paramResultId); - state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId); + blocks.back().Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId); } - SpirvAstVisitor visitor(*this); + SpirvAstVisitor visitor(*this, blocks); visitor.Visit(functionStatements[funcIndex]); if (func.returnType == ShaderNodes::BasicType::Void) - state.instructions.Append(SpirvOp::OpReturn); + blocks.back().Append(SpirvOp::OpReturn); + else + throw std::runtime_error("returning values from functions is not yet supported"); //< TODO - state.instructions.Append(SpirvOp::OpFunctionEnd); + blocks.back().Append(SpirvOp::OpFunctionEnd); + + for (SpirvBlock& block : blocks) + state.instructions.Append(block); } assert(entryPointIndex != std::numeric_limits::max()); @@ -552,11 +559,6 @@ namespace Nz return it.value(); } - SpirvSection& SpirvWriter::GetInstructions() - { - return m_currentState->instructions; - } - UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const { return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));