Shader: Rework Swizzle and add support for swizzle store in SPIRV

This commit is contained in:
Jérôme Leclercq 2021-12-21 14:30:47 +01:00
parent e43a638112
commit 837b72f68e
11 changed files with 151 additions and 102 deletions

View File

@ -128,14 +128,6 @@ namespace Nz
UInt32, //< ui32
};
enum class SwizzleComponent
{
First,
Second,
Third,
Fourth
};
enum class UnaryType
{
LogicalNot, //< !v

View File

@ -178,7 +178,7 @@ namespace Nz::ShaderAst
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::array<SwizzleComponent, 4> components;
std::array<UInt32, 4> components;
std::size_t componentCount;
ExpressionPtr expression;
};

View File

@ -133,7 +133,7 @@ namespace Nz::ShaderBuilder
struct Swizzle
{
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::vector<ShaderAst::SwizzleComponent> swizzleComponents) const;
inline std::unique_ptr<ShaderAst::SwizzleExpression> operator()(ShaderAst::ExpressionPtr expression, std::vector<UInt32> swizzleComponents) const;
};
struct Unary

View File

@ -277,7 +277,7 @@ namespace Nz::ShaderBuilder
return returnNode;
}
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector<ShaderAst::SwizzleComponent> swizzleComponents) const
inline std::unique_ptr<ShaderAst::SwizzleExpression> Impl::Swizzle::operator()(ShaderAst::ExpressionPtr expression, std::vector<UInt32> swizzleComponents) const
{
auto swizzleNode = std::make_unique<ShaderAst::SwizzleExpression>();
swizzleNode->expression = std::move(expression);
@ -285,7 +285,10 @@ namespace Nz::ShaderBuilder
assert(swizzleComponents.size() <= swizzleNode->components.size());
swizzleNode->componentCount = swizzleComponents.size();
for (std::size_t i = 0; i < swizzleNode->componentCount; ++i)
{
assert(swizzleComponents[i] >= 0 && swizzleComponents[i] <= 4);
swizzleNode->components[i] = swizzleComponents[i];
}
return swizzleNode;
}

View File

@ -11,6 +11,7 @@
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <Nazara/Shader/Ast/AstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/Ast/Enums.hpp>
namespace Nz
{
@ -37,21 +38,23 @@ namespace Nz
SpirvExpressionStore& operator=(SpirvExpressionStore&&) = delete;
private:
struct LocalVar
{
std::string varName;
};
struct Pointer
{
SpirvStorageClass storage;
UInt32 pointerId;
};
struct SwizzledPointer : Pointer
{
ShaderAst::VectorType swizzledType;
std::array<UInt32, 4> swizzleIndices;
std::size_t componentCount;
};
SpirvAstVisitor& m_visitor;
SpirvBlock& m_block;
SpirvWriter& m_writer;
std::variant<std::monostate, LocalVar, Pointer> m_value;
std::variant<std::monostate, Pointer, SwizzledPointer> m_value;
};
}

View File

@ -200,26 +200,29 @@ namespace Nz::ShaderAst
case 'r':
case 'x':
case 's':
swizzle->components[j] = SwizzleComponent::First;
swizzle->components[j] = 0u;
break;
case 'g':
case 'y':
case 't':
swizzle->components[j] = SwizzleComponent::Second;
swizzle->components[j] = 1u;
break;
case 'b':
case 'z':
case 'p':
swizzle->components[j] = SwizzleComponent::Third;
swizzle->components[j] = 2u;
break;
case 'a':
case 'w':
case 'q':
swizzle->components[j] = SwizzleComponent::Fourth;
swizzle->components[j] = 3u;
break;
default:
throw AstError{ "unexpected character '" + std::string(swizzleStr) + "' on swizzle " };
}
}
@ -303,8 +306,8 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<CallFunctionExpression>();
clone->parameters.reserve(node.parameters.size());
for (std::size_t i = 0; i < node.parameters.size(); ++i)
clone->parameters.push_back(CloneExpression(node.parameters[i]));
for (const auto& parameter : node.parameters)
clone->parameters.push_back(CloneExpression(parameter));
std::size_t targetFuncIndex;
if (std::holds_alternative<std::string>(node.targetFunction))
@ -495,6 +498,12 @@ namespace Nz::ShaderAst
if (node.componentCount > 4)
throw AstError{ "Cannot swizzle more than four elements" };
for (UInt32 swizzleIndex : node.components)
{
if (swizzleIndex >= 4)
throw AstError{ "invalid swizzle" };
}
MandatoryExpr(node.expression);
auto clone = static_unique_pointer_cast<SwizzleExpression>(AstCloner::Clone(node));
@ -1315,12 +1324,10 @@ namespace Nz::ShaderAst
}
ExpressionType exprType = GetExpressionType(*node.expr);
for (std::size_t i = 0; i < node.indices.size(); ++i)
for (const auto& indexExpr : node.indices)
{
if (IsStructType(exprType))
{
auto& indexExpr = node.indices[i];
const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr);
if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 })
throw AstError{ "struct can only be accessed with constant i32 indices" };

View File

@ -1195,27 +1195,9 @@ namespace Nz
Visit(node.expression, true);
Append(".");
const char* componentStr = "xyzw";
for (std::size_t i = 0; i < node.componentCount; ++i)
{
switch (node.components[i])
{
case ShaderAst::SwizzleComponent::First:
Append("x");
break;
case ShaderAst::SwizzleComponent::Second:
Append("y");
break;
case ShaderAst::SwizzleComponent::Third:
Append("z");
break;
case ShaderAst::SwizzleComponent::Fourth:
Append("w");
break;
}
}
Append(componentStr[node.components[i]]);
}
void GlslWriter::Visit(ShaderAst::VariableExpression& node)

View File

@ -881,27 +881,9 @@ namespace Nz
Visit(node.expression, true);
Append(".");
const char* componentStr = "xyzw";
for (std::size_t i = 0; i < node.componentCount; ++i)
{
switch (node.components[i])
{
case ShaderAst::SwizzleComponent::First:
Append("x");
break;
case ShaderAst::SwizzleComponent::Second:
Append("y");
break;
case ShaderAst::SwizzleComponent::Third:
Append("z");
break;
case ShaderAst::SwizzleComponent::Fourth:
Append("w");
break;
}
}
Append(componentStr[node.components[i]]);
}
void LangWriter::Visit(ShaderAst::VariableExpression& node)

View File

@ -906,8 +906,9 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{
const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression);
UInt32 exprResultId = EvaluateExpression(node.expression);
UInt32 resultId = m_writer.AllocateResultId();
const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node);
@ -917,31 +918,61 @@ namespace Nz
const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(targetExprType);
// Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands
m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
UInt32 resultId = m_writer.AllocateResultId();
if (IsVectorType(swizzledExpressionType))
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
appender(exprResultId);
appender(exprResultId);
// Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands
m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
appender(exprResultId);
appender(exprResultId);
for (std::size_t i = 0; i < node.componentCount; ++i)
appender(UInt32(node.components[i]));
});
for (std::size_t i = 0; i < node.componentCount; ++i)
appender(node.components[i]);
});
}
else
{
assert(IsPrimitiveType(swizzledExpressionType));
// Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
for (std::size_t i = 0; i < node.componentCount; ++i)
appender(exprResultId);
});
}
PushResultId(resultId);
}
else
else if (IsVectorType(swizzledExpressionType))
{
assert(IsPrimitiveType(targetExprType));
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
// Extract a single component from the vector
assert(node.componentCount == 1);
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderAst::SwizzleComponent::First) );
}
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, node.components[0]);
PushResultId(resultId);
PushResultId(resultId);
}
else
{
// Swizzling a primitive to itself (a.x for example), don't do anything
assert(IsPrimitiveType(swizzledExpressionType));
assert(IsPrimitiveType(targetExprType));
assert(node.componentCount == 1);
assert(node.components[0] == 0);
PushResultId(exprResultId);
}
}
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)

View File

@ -7,6 +7,7 @@
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <numeric>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
@ -27,9 +28,57 @@ namespace Nz
{
m_block.Append(SpirvOp::OpStore, pointer.pointerId, resultId);
},
[&](const LocalVar& value)
[&](const SwizzledPointer& swizzledPointer)
{
throw std::runtime_error("not yet implemented");
if (swizzledPointer.componentCount > 1)
{
std::size_t vectorSize = swizzledPointer.swizzledType.componentCount;
UInt32 exprTypeId = m_writer.GetTypeId(swizzledPointer.swizzledType);
// Load original value (which will then be shuffled with new value)
UInt32 originalVecId = m_visitor.AllocateResultId();
m_block.Append(SpirvOp::OpLoad, exprTypeId, originalVecId, swizzledPointer.pointerId);
// Build a new composite type using OpVectorShuffle and store it
StackArray<UInt32> indices = NazaraStackArrayNoInit(UInt32, vectorSize);
std::iota(indices.begin(), indices.end(), UInt32(0u)); //< init with regular swizzle (0,1,2,3)
// override with swizzle components
for (std::size_t i = 0; i < swizzledPointer.componentCount; ++i)
indices[swizzledPointer.swizzleIndices[i]] = SafeCast<UInt32>(vectorSize + i);
UInt32 shuffleResultId = m_visitor.AllocateResultId();
m_block.AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
{
appender(exprTypeId);
appender(shuffleResultId);
appender(originalVecId);
appender(resultId);
for (UInt32 index : indices)
appender(index);
});
// Store result
m_block.Append(SpirvOp::OpStore, swizzledPointer.pointerId, shuffleResultId);
}
else
{
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node);
assert(swizzledPointer.componentCount == 1);
UInt32 pointerType = m_writer.RegisterPointerType(exprType, swizzledPointer.storage); //< FIXME
// Access chain
UInt32 indexId = m_writer.GetConstantId(SafeCast<Int32>(swizzledPointer.swizzleIndices[0]));
UInt32 pointerId = m_visitor.AllocateResultId();
m_block.Append(SpirvOp::OpAccessChain, pointerType, pointerId, swizzledPointer.pointerId, indexId);
m_block.Append(SpirvOp::OpStore, pointerId, resultId);
}
},
[](std::monostate)
{
@ -67,10 +116,6 @@ namespace Nz
m_value = Pointer { pointer.storage, resultId };
},
[&](const LocalVar& value)
{
throw std::runtime_error("not yet implemented");
},
[](std::monostate)
{
throw std::runtime_error("an internal error occurred");
@ -80,30 +125,34 @@ namespace Nz
void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node)
{
if (node.componentCount != 1)
throw std::runtime_error("swizzle with more than one component is not yet supported");
node.expression->Visit(*this);
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
std::visit(overloaded
{
[&](const Pointer& pointer)
{
UInt32 resultId = m_visitor.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
const auto& expressionType = GetExpressionType(*node.expression);
assert(IsVectorType(expressionType));
Int32 indexCount = UnderlyingCast(node.components[0]) - UnderlyingCast(ShaderAst::SwizzleComponent::First);
UInt32 indexId = m_writer.GetConstantId(indexCount);
SwizzledPointer swizzledPointer{ pointer };
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(expressionType);
swizzledPointer.componentCount = node.componentCount;
swizzledPointer.swizzleIndices = node.components;
m_block.Append(SpirvOp::OpAccessChain, pointerType, resultId, pointer.pointerId, indexId);
m_value = Pointer { pointer.storage, resultId };
m_value = swizzledPointer;
},
[&](const LocalVar& value)
[&](SwizzledPointer& swizzledPointer)
{
throw std::runtime_error("not yet implemented");
// Swizzle the swizzle, keep common components
std::array<UInt32, 4> newIndices;
for (std::size_t i = 0; i < node.componentCount; ++i)
{
assert(node.components[i] < node.componentCount);
newIndices[i] = swizzledPointer.swizzleIndices[node.components[i]];
}
swizzledPointer.componentCount = node.componentCount;
swizzledPointer.swizzleIndices = newIndices;
},
[](std::monostate)
{

View File

@ -359,7 +359,7 @@ namespace Nz
for (std::size_t i = 0; i < node.componentCount; ++i)
{
Int32 indexCount = UnderlyingCast(node.components[i]) - UnderlyingCast(ShaderAst::SwizzleComponent::First);
Int32 indexCount = SafeCast<Int32>(node.components[i]);
m_constantCache.Register(*m_constantCache.BuildConstant(indexCount));
}