From 48b93a9deac57d5d1e54397fd4bdbdcbdb512d75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Thu, 11 Mar 2021 17:50:11 +0100 Subject: [PATCH] Shader: Fix SPIRV generation --- include/Nazara/Shader/SpirvAstVisitor.hpp | 5 +- include/Nazara/Shader/SpirvAstVisitor.inl | 3 +- include/Nazara/Shader/SpirvConstantCache.hpp | 1 + include/Nazara/Shader/SpirvWriter.hpp | 7 +- src/Nazara/Shader/ShaderAstExpressionType.cpp | 2 + src/Nazara/Shader/SpirvAstVisitor.cpp | 10 +-- src/Nazara/Shader/SpirvConstantCache.cpp | 22 ++++- src/Nazara/Shader/SpirvWriter.cpp | 89 +++++++++---------- 8 files changed, 78 insertions(+), 61 deletions(-) diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index ffead5fef..b6536ea2e 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -21,7 +21,7 @@ namespace Nz class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept { public: - inline SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks); + inline SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks, ShaderAst::AstCache* cache); SpirvAstVisitor(const SpirvAstVisitor&) = delete; SpirvAstVisitor(SpirvAstVisitor&&) = delete; ~SpirvAstVisitor() = default; @@ -56,9 +56,10 @@ namespace Nz void PushResultId(UInt32 value); UInt32 PopResultId(); - SpirvBlock* m_currentBlock; + ShaderAst::AstCache* m_cache; std::vector& m_blocks; std::vector m_resultIds; + SpirvBlock* m_currentBlock; SpirvWriter& m_writer; }; } diff --git a/include/Nazara/Shader/SpirvAstVisitor.inl b/include/Nazara/Shader/SpirvAstVisitor.inl index 048f5768e..8694244be 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.inl +++ b/include/Nazara/Shader/SpirvAstVisitor.inl @@ -7,7 +7,8 @@ namespace Nz { - inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks) : + inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector& blocks, ShaderAst::AstCache* cache) : + m_cache(cache), m_blocks(blocks), m_writer(writer) { diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index 54a53584c..7cc829518 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -172,6 +172,7 @@ namespace Nz SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept; static ConstantPtr BuildConstant(const ShaderConstantValue& value); + static TypePtr BuildFunctionType(const ShaderAst::ShaderExpressionType& retType, const std::vector& parameters); static TypePtr BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass); static TypePtr BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass); static TypePtr BuildType(const ShaderAst::BasicType& type); diff --git a/include/Nazara/Shader/SpirvWriter.hpp b/include/Nazara/Shader/SpirvWriter.hpp index a8af651cd..d0c5f561d 100644 --- a/include/Nazara/Shader/SpirvWriter.hpp +++ b/include/Nazara/Shader/SpirvWriter.hpp @@ -56,10 +56,8 @@ namespace Nz void AppendHeader(); - SpirvConstantCache::Function BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters); - UInt32 GetConstantId(const ShaderConstantValue& value) const; - UInt32 GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector& parameters); + UInt32 GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode); const ExtVar& GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const; const ExtVar& GetInputVariable(const std::string& name) const; const ExtVar& GetOutputVariable(const std::string& name) const; @@ -81,12 +79,13 @@ namespace Nz std::optional ReadVariable(const ExtVar& var, OnlyCache); UInt32 RegisterConstant(const ShaderConstantValue& value); - UInt32 RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters); + UInt32 RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode); UInt32 RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass); UInt32 RegisterType(ShaderAst::ShaderExpressionType type); void WriteLocalVariable(std::string name, UInt32 resultId); + static SpirvConstantCache::TypePtr BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode); static void MergeSections(std::vector& output, const SpirvSection& from); struct Context diff --git a/src/Nazara/Shader/ShaderAstExpressionType.cpp b/src/Nazara/Shader/ShaderAstExpressionType.cpp index be4f238e9..9a6b394d7 100644 --- a/src/Nazara/Shader/ShaderAstExpressionType.cpp +++ b/src/Nazara/Shader/ShaderAstExpressionType.cpp @@ -163,6 +163,8 @@ namespace Nz::ShaderAst void ExpressionTypeVisitor::Visit(IdentifierExpression& node) { + assert(m_cache); + auto scopeIt = m_cache->scopeIdByNode.find(&node); if (scopeIt == m_cache->scopeIdByNode.end()) throw std::runtime_error("internal error"); diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 96a099a21..0dd4737a3 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -39,13 +39,13 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node) { - ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node); + ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache); assert(IsBasicType(resultExprType)); - ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left); + ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache); assert(IsBasicType(leftExprType)); - ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right); + ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache); assert(IsBasicType(rightExprType)); ShaderAst::BasicType resultType = std::get(resultExprType); @@ -582,7 +582,7 @@ namespace Nz { case ShaderAst::IntrinsicType::DotProduct: { - const ShaderAst::ShaderExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); + ShaderAst::ShaderExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache); assert(IsBasicType(vecExprType)); ShaderAst::BasicType vecType = std::get(vecExprType); @@ -626,7 +626,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) { - const ShaderAst::ShaderExpressionType& targetExprType = ShaderAst::GetExpressionType(node); + ShaderAst::ShaderExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache); assert(IsBasicType(targetExprType)); ShaderAst::BasicType targetType = std::get(targetExprType); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index 05108f1f6..97d40a509 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -570,18 +570,32 @@ namespace Nz }, value)); } + auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ShaderExpressionType& retType, const std::vector& parameters) -> TypePtr + { + std::vector parameterTypes; + parameterTypes.reserve(parameters.size()); + + for (const auto& parameterType : parameters) + parameterTypes.push_back(BuildPointerType(parameterType, SpirvStorageClass::Function)); + + return std::make_shared(Function{ + BuildType(retType), + std::move(parameterTypes) + }); + } + auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr { - return std::make_shared(SpirvConstantCache::Pointer{ - SpirvConstantCache::BuildType(type), + return std::make_shared(Pointer{ + BuildType(type), storageClass }); } auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr { - return std::make_shared(SpirvConstantCache::Pointer{ - SpirvConstantCache::BuildType(type), + return std::make_shared(Pointer{ + BuildType(type), storageClass }); } diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index d108d7efb..c877db5a0 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -31,6 +31,7 @@ namespace Nz public: using ExtInstList = std::unordered_set; using LocalContainer = std::unordered_set; + using FunctionContainer = std::vector>; PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) : m_cache(cache), @@ -79,9 +80,15 @@ namespace Nz void Visit(ShaderAst::DeclareFunctionStatement& node) override { - m_constantCache.Register(*SpirvConstantCache::BuildType(node.returnType)); + funcs.emplace_back(node); + + std::vector parameterTypes; for (auto& parameter : node.parameters) - m_constantCache.Register(*SpirvConstantCache::BuildType(parameter.type)); + parameterTypes.push_back(parameter.type); + + m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes)); + + AstRecursiveVisitor::Visit(node); } void Visit(ShaderAst::DeclareStructStatement& node) override @@ -92,14 +99,14 @@ namespace Nz void Visit(ShaderAst::DeclareVariableStatement& node) override { - variableTypes.insert(node.varType); + m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType)); AstRecursiveVisitor::Visit(node); } void Visit(ShaderAst::IdentifierExpression& node) override { - variableTypes.insert(GetExpressionType(node, m_cache)); + m_constantCache.Register(*SpirvConstantCache::BuildType(GetExpressionType(node, m_cache))); AstRecursiveVisitor::Visit(node); } @@ -122,7 +129,7 @@ namespace Nz } ExtInstList extInsts; - LocalContainer variableTypes; + FunctionContainer funcs; private: ShaderAst::AstCache* m_cache; @@ -210,8 +217,6 @@ namespace Nz m_currentState = nullptr; }); - std::vector functionStatements; - ShaderAst::AstCloner cloner; // Register all extended instruction sets @@ -345,26 +350,22 @@ namespace Nz state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex); state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0); } - } + }*/ - for (const auto& func : shader.GetFunctions()) + for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs) { auto& funcData = state.funcs.emplace_back(); funcData.id = AllocateResultId(); - funcData.typeId = GetFunctionTypeId(func.returnType, func.parameters); + funcData.typeId = GetFunctionTypeId(func); state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name); } - std::size_t entryPointIndex = std::numeric_limits::max(); + std::size_t funcIndex = 0; - for (std::size_t funcIndex = 0; funcIndex < shader.GetFunctionCount(); ++funcIndex) + for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs) { - const auto& func = shader.GetFunction(funcIndex); - if (func.name == "main") - entryPointIndex = funcIndex; - - auto& funcData = state.funcs[funcIndex]; + auto& funcData = state.funcs[funcIndex++]; state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId); @@ -386,8 +387,9 @@ namespace Nz state.parameterIds.emplace(param.name, std::move(parameterData)); } - SpirvAstVisitor visitor(*this, state.functionBlocks); - visitor.Visit(functionStatements[funcIndex]); + SpirvAstVisitor visitor(*this, state.functionBlocks, &m_context.cache); + for (const auto& statement : func.statements) + statement->Visit(visitor); if (!state.functionBlocks.back().IsTerminated()) { @@ -405,7 +407,7 @@ namespace Nz AppendHeader(); - if (entryPointIndex != std::numeric_limits::max()) + /*if (entryPointIndex != std::numeric_limits::max()) { SpvExecutionModel execModel; const auto& entryFuncData = shader.GetFunction(entryPointIndex); @@ -415,11 +417,11 @@ namespace Nz switch (m_context.shader->GetStage()) { case ShaderStageType::Fragment: - execModel = SpvExecutionModelFragment; + execModel = SpirvExecutionModel::Fragment; break; case ShaderStageType::Vertex: - execModel = SpvExecutionModelVertex; + execModel = SpirvExecutionModel::Vertex; break; default: @@ -445,15 +447,15 @@ namespace Nz }); if (m_context.shader->GetStage() == ShaderStageType::Fragment) - state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft); + state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft); }*/ std::vector ret; - /*MergeSections(ret, state.header); + MergeSections(ret, state.header); MergeSections(ret, state.debugInfo); MergeSections(ret, state.annotations); MergeSections(ret, state.constants); - MergeSections(ret, state.instructions);*/ + MergeSections(ret, state.instructions); return ret; } @@ -479,26 +481,12 @@ namespace Nz m_currentState->header.AppendRaw(m_currentState->nextVarIndex); //< Bound (ID count) m_currentState->header.AppendRaw(0); //< Instruction schema (required to be 0 for now) - m_currentState->header.Append(SpirvOp::OpCapability, SpvCapabilityShader); + m_currentState->header.Append(SpirvOp::OpCapability, SpirvCapability::Shader); for (const auto& [extInst, resultId] : m_currentState->extensionInstructions) m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst); - m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450); - } - - SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters) - { - std::vector parameterTypes; - parameterTypes.reserve(parameters.size()); - - for (const auto& parameter : parameters) - parameterTypes.push_back(SpirvConstantCache::BuildPointerType(parameter.type, SpirvStorageClass::Function)); - - return SpirvConstantCache::Function{ - SpirvConstantCache::BuildType(retType), - std::move(parameterTypes) - }; + m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450); } UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const @@ -506,9 +494,9 @@ namespace Nz return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value)); } - UInt32 SpirvWriter::GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector& parameters) + UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode) { - return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) }); + return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) }); } auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar& @@ -644,9 +632,9 @@ namespace Nz return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value)); } - UInt32 SpirvWriter::RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector& parameters) + UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode) { - return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) }); + return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) }); } UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass) @@ -666,6 +654,17 @@ namespace Nz m_currentState->varToResult.insert_or_assign(std::move(name), resultId); } + SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode) + { + std::vector parameterTypes; + parameterTypes.reserve(functionNode.parameters.size()); + + for (const auto& parameter : functionNode.parameters) + parameterTypes.push_back(parameter.type); + + return SpirvConstantCache::BuildFunctionType(functionNode.returnType, parameterTypes); + } + void SpirvWriter::MergeSections(std::vector& output, const SpirvSection& from) { const std::vector& bytecode = from.GetBytecode();