Shader: Rework Swizzle and add support for swizzle store in SPIRV

This commit is contained in:
Jérôme Leclercq
2021-12-21 14:30:47 +01:00
parent e43a638112
commit 837b72f68e
11 changed files with 151 additions and 102 deletions

View File

@@ -7,6 +7,7 @@
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <numeric>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
@@ -27,9 +28,57 @@ namespace Nz
{
m_block.Append(SpirvOp::OpStore, pointer.pointerId, resultId);
},
[&](const LocalVar& value)
[&](const SwizzledPointer& swizzledPointer)
{
throw std::runtime_error("not yet implemented");
if (swizzledPointer.componentCount > 1)
{
std::size_t vectorSize = swizzledPointer.swizzledType.componentCount;
UInt32 exprTypeId = m_writer.GetTypeId(swizzledPointer.swizzledType);
// Load original value (which will then be shuffled with new value)
UInt32 originalVecId = m_visitor.AllocateResultId();
m_block.Append(SpirvOp::OpLoad, exprTypeId, originalVecId, swizzledPointer.pointerId);
// Build a new composite type using OpVectorShuffle and store it
StackArray<UInt32> indices = NazaraStackArrayNoInit(UInt32, vectorSize);
std::iota(indices.begin(), indices.end(), UInt32(0u)); //< init with regular swizzle (0,1,2,3)
// override with swizzle components
for (std::size_t i = 0; i < swizzledPointer.componentCount; ++i)
indices[swizzledPointer.swizzleIndices[i]] = SafeCast<UInt32>(vectorSize + i);
UInt32 shuffleResultId = m_visitor.AllocateResultId();
m_block.AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
{
appender(exprTypeId);
appender(shuffleResultId);
appender(originalVecId);
appender(resultId);
for (UInt32 index : indices)
appender(index);
});
// Store result
m_block.Append(SpirvOp::OpStore, swizzledPointer.pointerId, shuffleResultId);
}
else
{
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node);
assert(swizzledPointer.componentCount == 1);
UInt32 pointerType = m_writer.RegisterPointerType(exprType, swizzledPointer.storage); //< FIXME
// Access chain
UInt32 indexId = m_writer.GetConstantId(SafeCast<Int32>(swizzledPointer.swizzleIndices[0]));
UInt32 pointerId = m_visitor.AllocateResultId();
m_block.Append(SpirvOp::OpAccessChain, pointerType, pointerId, swizzledPointer.pointerId, indexId);
m_block.Append(SpirvOp::OpStore, pointerId, resultId);
}
},
[](std::monostate)
{
@@ -67,10 +116,6 @@ namespace Nz
m_value = Pointer { pointer.storage, resultId };
},
[&](const LocalVar& value)
{
throw std::runtime_error("not yet implemented");
},
[](std::monostate)
{
throw std::runtime_error("an internal error occurred");
@@ -80,30 +125,34 @@ namespace Nz
void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node)
{
if (node.componentCount != 1)
throw std::runtime_error("swizzle with more than one component is not yet supported");
node.expression->Visit(*this);
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
std::visit(overloaded
{
[&](const Pointer& pointer)
{
UInt32 resultId = m_visitor.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
const auto& expressionType = GetExpressionType(*node.expression);
assert(IsVectorType(expressionType));
Int32 indexCount = UnderlyingCast(node.components[0]) - UnderlyingCast(ShaderAst::SwizzleComponent::First);
UInt32 indexId = m_writer.GetConstantId(indexCount);
SwizzledPointer swizzledPointer{ pointer };
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(expressionType);
swizzledPointer.componentCount = node.componentCount;
swizzledPointer.swizzleIndices = node.components;
m_block.Append(SpirvOp::OpAccessChain, pointerType, resultId, pointer.pointerId, indexId);
m_value = Pointer { pointer.storage, resultId };
m_value = swizzledPointer;
},
[&](const LocalVar& value)
[&](SwizzledPointer& swizzledPointer)
{
throw std::runtime_error("not yet implemented");
// Swizzle the swizzle, keep common components
std::array<UInt32, 4> newIndices;
for (std::size_t i = 0; i < node.componentCount; ++i)
{
assert(node.components[i] < node.componentCount);
newIndices[i] = swizzledPointer.swizzleIndices[node.components[i]];
}
swizzledPointer.componentCount = node.componentCount;
swizzledPointer.swizzleIndices = newIndices;
},
[](std::monostate)
{