diff --git a/include/Nazara/Shader/Ast/Enums.hpp b/include/Nazara/Shader/Ast/Enums.hpp index 3a9a91f91..99df966f4 100644 --- a/include/Nazara/Shader/Ast/Enums.hpp +++ b/include/Nazara/Shader/Ast/Enums.hpp @@ -128,14 +128,6 @@ namespace Nz UInt32, //< ui32 }; - enum class SwizzleComponent - { - First, - Second, - Third, - Fourth - }; - enum class UnaryType { LogicalNot, //< !v diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 7dfceec53..555f06b28 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -178,7 +178,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; - std::array components; + std::array components; std::size_t componentCount; ExpressionPtr expression; }; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 5e0a77400..780c795f9 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -133,7 +133,7 @@ namespace Nz::ShaderBuilder struct Swizzle { - inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression, std::vector swizzleComponents) const; + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr expression, std::vector swizzleComponents) const; }; struct Unary diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 52188a99c..bc964b26a 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -277,7 +277,7 @@ namespace Nz::ShaderBuilder return returnNode; } - inline std::unique_ptr Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector swizzleComponents) const + inline std::unique_ptr Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector swizzleComponents) const { auto swizzleNode = std::make_unique(); swizzleNode->expression = std::move(expression); @@ -285,7 +285,10 @@ namespace Nz::ShaderBuilder assert(swizzleComponents.size() <= swizzleNode->components.size()); swizzleNode->componentCount = swizzleComponents.size(); for (std::size_t i = 0; i < swizzleNode->componentCount; ++i) + { + assert(swizzleComponents[i] >= 0 && swizzleComponents[i] <= 4); swizzleNode->components[i] = swizzleComponents[i]; + } return swizzleNode; } diff --git a/include/Nazara/Shader/SpirvExpressionStore.hpp b/include/Nazara/Shader/SpirvExpressionStore.hpp index 8d3ea3884..183b54bc8 100644 --- a/include/Nazara/Shader/SpirvExpressionStore.hpp +++ b/include/Nazara/Shader/SpirvExpressionStore.hpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace Nz { @@ -37,21 +38,23 @@ namespace Nz SpirvExpressionStore& operator=(SpirvExpressionStore&&) = delete; private: - struct LocalVar - { - std::string varName; - }; - struct Pointer { SpirvStorageClass storage; UInt32 pointerId; }; + struct SwizzledPointer : Pointer + { + ShaderAst::VectorType swizzledType; + std::array swizzleIndices; + std::size_t componentCount; + }; + SpirvAstVisitor& m_visitor; SpirvBlock& m_block; SpirvWriter& m_writer; - std::variant m_value; + std::variant m_value; }; } diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 94c50af48..0526f7842 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -200,26 +200,29 @@ namespace Nz::ShaderAst case 'r': case 'x': case 's': - swizzle->components[j] = SwizzleComponent::First; + swizzle->components[j] = 0u; break; case 'g': case 'y': case 't': - swizzle->components[j] = SwizzleComponent::Second; + swizzle->components[j] = 1u; break; case 'b': case 'z': case 'p': - swizzle->components[j] = SwizzleComponent::Third; + swizzle->components[j] = 2u; break; case 'a': case 'w': case 'q': - swizzle->components[j] = SwizzleComponent::Fourth; + swizzle->components[j] = 3u; break; + + default: + throw AstError{ "unexpected character '" + std::string(swizzleStr) + "' on swizzle " }; } } @@ -303,8 +306,8 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->parameters.reserve(node.parameters.size()); - for (std::size_t i = 0; i < node.parameters.size(); ++i) - clone->parameters.push_back(CloneExpression(node.parameters[i])); + for (const auto& parameter : node.parameters) + clone->parameters.push_back(CloneExpression(parameter)); std::size_t targetFuncIndex; if (std::holds_alternative(node.targetFunction)) @@ -495,6 +498,12 @@ namespace Nz::ShaderAst if (node.componentCount > 4) throw AstError{ "Cannot swizzle more than four elements" }; + for (UInt32 swizzleIndex : node.components) + { + if (swizzleIndex >= 4) + throw AstError{ "invalid swizzle" }; + } + MandatoryExpr(node.expression); auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); @@ -1315,12 +1324,10 @@ namespace Nz::ShaderAst } ExpressionType exprType = GetExpressionType(*node.expr); - for (std::size_t i = 0; i < node.indices.size(); ++i) + for (const auto& indexExpr : node.indices) { if (IsStructType(exprType)) { - auto& indexExpr = node.indices[i]; - const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr); if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 }) throw AstError{ "struct can only be accessed with constant i32 indices" }; diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 5bc853672..d9573a754 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -1195,27 +1195,9 @@ namespace Nz Visit(node.expression, true); Append("."); + const char* componentStr = "xyzw"; for (std::size_t i = 0; i < node.componentCount; ++i) - { - switch (node.components[i]) - { - case ShaderAst::SwizzleComponent::First: - Append("x"); - break; - - case ShaderAst::SwizzleComponent::Second: - Append("y"); - break; - - case ShaderAst::SwizzleComponent::Third: - Append("z"); - break; - - case ShaderAst::SwizzleComponent::Fourth: - Append("w"); - break; - } - } + Append(componentStr[node.components[i]]); } void GlslWriter::Visit(ShaderAst::VariableExpression& node) diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index dcad1f098..5a198bc71 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -881,27 +881,9 @@ namespace Nz Visit(node.expression, true); Append("."); + const char* componentStr = "xyzw"; for (std::size_t i = 0; i < node.componentCount; ++i) - { - switch (node.components[i]) - { - case ShaderAst::SwizzleComponent::First: - Append("x"); - break; - - case ShaderAst::SwizzleComponent::Second: - Append("y"); - break; - - case ShaderAst::SwizzleComponent::Third: - Append("z"); - break; - - case ShaderAst::SwizzleComponent::Fourth: - Append("w"); - break; - } - } + Append(componentStr[node.components[i]]); } void LangWriter::Visit(ShaderAst::VariableExpression& node) diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index 56583ab15..2523bdb60 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -906,8 +906,9 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) { + const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression); + UInt32 exprResultId = EvaluateExpression(node.expression); - UInt32 resultId = m_writer.AllocateResultId(); const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node); @@ -917,31 +918,61 @@ namespace Nz const ShaderAst::VectorType& targetType = std::get(targetExprType); - // Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands - m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender) + UInt32 resultId = m_writer.AllocateResultId(); + if (IsVectorType(swizzledExpressionType)) { - appender(m_writer.GetTypeId(targetType)); - appender(resultId); - appender(exprResultId); - appender(exprResultId); + // Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands + m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender) + { + appender(m_writer.GetTypeId(targetType)); + appender(resultId); + appender(exprResultId); + appender(exprResultId); - for (std::size_t i = 0; i < node.componentCount; ++i) - appender(UInt32(node.components[i])); - }); + for (std::size_t i = 0; i < node.componentCount; ++i) + appender(node.components[i]); + }); + } + else + { + assert(IsPrimitiveType(swizzledExpressionType)); + + // Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct + m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender) + { + appender(m_writer.GetTypeId(targetType)); + appender(resultId); + + for (std::size_t i = 0; i < node.componentCount; ++i) + appender(exprResultId); + }); + } + + PushResultId(resultId); } - else + else if (IsVectorType(swizzledExpressionType)) { assert(IsPrimitiveType(targetExprType)); - ShaderAst::PrimitiveType targetType = std::get(targetExprType); // Extract a single component from the vector assert(node.componentCount == 1); - m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderAst::SwizzleComponent::First) ); - } + UInt32 resultId = m_writer.AllocateResultId(); + m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, node.components[0]); - PushResultId(resultId); + PushResultId(resultId); + } + else + { + // Swizzling a primitive to itself (a.x for example), don't do anything + assert(IsPrimitiveType(swizzledExpressionType)); + assert(IsPrimitiveType(targetExprType)); + assert(node.componentCount == 1); + assert(node.components[0] == 0); + + PushResultId(exprResultId); + } } void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node) diff --git a/src/Nazara/Shader/SpirvExpressionStore.cpp b/src/Nazara/Shader/SpirvExpressionStore.cpp index f01ef2ae7..7bb15e53e 100644 --- a/src/Nazara/Shader/SpirvExpressionStore.cpp +++ b/src/Nazara/Shader/SpirvExpressionStore.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace Nz @@ -27,9 +28,57 @@ namespace Nz { m_block.Append(SpirvOp::OpStore, pointer.pointerId, resultId); }, - [&](const LocalVar& value) + [&](const SwizzledPointer& swizzledPointer) { - throw std::runtime_error("not yet implemented"); + if (swizzledPointer.componentCount > 1) + { + std::size_t vectorSize = swizzledPointer.swizzledType.componentCount; + + UInt32 exprTypeId = m_writer.GetTypeId(swizzledPointer.swizzledType); + + // Load original value (which will then be shuffled with new value) + UInt32 originalVecId = m_visitor.AllocateResultId(); + m_block.Append(SpirvOp::OpLoad, exprTypeId, originalVecId, swizzledPointer.pointerId); + + // Build a new composite type using OpVectorShuffle and store it + StackArray indices = NazaraStackArrayNoInit(UInt32, vectorSize); + std::iota(indices.begin(), indices.end(), UInt32(0u)); //< init with regular swizzle (0,1,2,3) + + // override with swizzle components + for (std::size_t i = 0; i < swizzledPointer.componentCount; ++i) + indices[swizzledPointer.swizzleIndices[i]] = SafeCast(vectorSize + i); + + UInt32 shuffleResultId = m_visitor.AllocateResultId(); + m_block.AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender) + { + appender(exprTypeId); + appender(shuffleResultId); + + appender(originalVecId); + appender(resultId); + + for (UInt32 index : indices) + appender(index); + }); + + // Store result + m_block.Append(SpirvOp::OpStore, swizzledPointer.pointerId, shuffleResultId); + } + else + { + const ShaderAst::ExpressionType& exprType = GetExpressionType(*node); + + assert(swizzledPointer.componentCount == 1); + + UInt32 pointerType = m_writer.RegisterPointerType(exprType, swizzledPointer.storage); //< FIXME + + // Access chain + UInt32 indexId = m_writer.GetConstantId(SafeCast(swizzledPointer.swizzleIndices[0])); + + UInt32 pointerId = m_visitor.AllocateResultId(); + m_block.Append(SpirvOp::OpAccessChain, pointerType, pointerId, swizzledPointer.pointerId, indexId); + m_block.Append(SpirvOp::OpStore, pointerId, resultId); + } }, [](std::monostate) { @@ -67,10 +116,6 @@ namespace Nz m_value = Pointer { pointer.storage, resultId }; }, - [&](const LocalVar& value) - { - throw std::runtime_error("not yet implemented"); - }, [](std::monostate) { throw std::runtime_error("an internal error occurred"); @@ -80,30 +125,34 @@ namespace Nz void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node) { - if (node.componentCount != 1) - throw std::runtime_error("swizzle with more than one component is not yet supported"); - node.expression->Visit(*this); - const ShaderAst::ExpressionType& exprType = GetExpressionType(node); - std::visit(overloaded { [&](const Pointer& pointer) { - UInt32 resultId = m_visitor.AllocateResultId(); - UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME + const auto& expressionType = GetExpressionType(*node.expression); + assert(IsVectorType(expressionType)); - Int32 indexCount = UnderlyingCast(node.components[0]) - UnderlyingCast(ShaderAst::SwizzleComponent::First); - UInt32 indexId = m_writer.GetConstantId(indexCount); + SwizzledPointer swizzledPointer{ pointer }; + swizzledPointer.swizzledType = std::get(expressionType); + swizzledPointer.componentCount = node.componentCount; + swizzledPointer.swizzleIndices = node.components; - m_block.Append(SpirvOp::OpAccessChain, pointerType, resultId, pointer.pointerId, indexId); - - m_value = Pointer { pointer.storage, resultId }; + m_value = swizzledPointer; }, - [&](const LocalVar& value) + [&](SwizzledPointer& swizzledPointer) { - throw std::runtime_error("not yet implemented"); + // Swizzle the swizzle, keep common components + std::array newIndices; + for (std::size_t i = 0; i < node.componentCount; ++i) + { + assert(node.components[i] < node.componentCount); + newIndices[i] = swizzledPointer.swizzleIndices[node.components[i]]; + } + + swizzledPointer.componentCount = node.componentCount; + swizzledPointer.swizzleIndices = newIndices; }, [](std::monostate) { diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 8e1503dc6..81920c6ef 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -359,7 +359,7 @@ namespace Nz for (std::size_t i = 0; i < node.componentCount; ++i) { - Int32 indexCount = UnderlyingCast(node.components[i]) - UnderlyingCast(ShaderAst::SwizzleComponent::First); + Int32 indexCount = SafeCast(node.components[i]); m_constantCache.Register(*m_constantCache.BuildConstant(indexCount)); }