Shader: Add basic support for Branch node in spir-v

This commit is contained in:
Jérôme Leclercq
2021-01-04 10:27:08 +01:00
parent 44bc86d082
commit 4d63d6e022
16 changed files with 186 additions and 48 deletions

View File

@@ -22,7 +22,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderNodes::AccessMember& node)
{
SpirvExpressionLoad accessMemberVisitor(m_writer);
SpirvExpressionLoad accessMemberVisitor(m_writer, *m_currentBlock);
PushResultId(accessMemberVisitor.Evaluate(node));
}
@@ -30,7 +30,7 @@ namespace Nz
{
UInt32 resultId = EvaluateExpression(node.right);
SpirvExpressionStore storeVisitor(m_writer);
SpirvExpressionStore storeVisitor(m_writer, *m_currentBlock);
storeVisitor.Store(node.left, resultId);
PushResultId(resultId);
@@ -438,10 +438,63 @@ namespace Nz
if (swapOperands)
std::swap(leftOperand, rightOperand);
m_writer.GetInstructions().Append(op, m_writer.GetTypeId(resultType), resultId, leftOperand, rightOperand);
m_currentBlock->Append(op, m_writer.GetTypeId(resultType), resultId, leftOperand, rightOperand);
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::Branch& node)
{
assert(!node.condStatements.empty());
auto& firstCond = node.condStatements.front();
UInt32 previousConditionId = EvaluateExpression(firstCond.condition);
SpirvBlock previousContentBlock(m_writer);
m_currentBlock = &previousContentBlock;
Visit(firstCond.statement);
std::optional<std::size_t> nextBlock;
for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex)
{
const auto& statement = node.condStatements[statementIndex];
SpirvBlock contentBlock(m_writer);
m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), contentBlock.GetLabelId());
previousConditionId = EvaluateExpression(statement.condition);
m_blocks.emplace_back(std::move(previousContentBlock));
previousContentBlock = std::move(contentBlock);
m_currentBlock = &previousContentBlock;
Visit(statement.statement);
}
SpirvBlock mergeBlock(m_writer);
if (node.elseStatement)
{
SpirvBlock elseBlock(m_writer);
m_currentBlock = &elseBlock;
Visit(node.elseStatement);
elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice
m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), elseBlock.GetLabelId());
m_blocks.emplace_back(std::move(previousContentBlock));
m_blocks.emplace_back(std::move(elseBlock));
}
else
{
m_blocks.back().Append(SpirvOp::OpBranchConditional, previousConditionId, previousContentBlock.GetLabelId(), mergeBlock.GetLabelId());
m_blocks.emplace_back(std::move(previousContentBlock));
}
m_blocks.emplace_back(std::move(mergeBlock));
m_currentBlock = &m_blocks.back();
}
void SpirvAstVisitor::Visit(ShaderNodes::Cast& node)
{
const ShaderExpressionType& targetExprType = node.exprType;
@@ -461,7 +514,7 @@ namespace Nz
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
@@ -508,7 +561,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderNodes::Discard& /*node*/)
{
m_writer.GetInstructions().Append(SpirvOp::OpKill);
m_currentBlock->Append(SpirvOp::OpKill);
}
void SpirvAstVisitor::Visit(ShaderNodes::ExpressionStatement& node)
@@ -519,7 +572,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderNodes::Identifier& node)
{
SpirvExpressionLoad loadVisitor(m_writer);
SpirvExpressionLoad loadVisitor(m_writer, *m_currentBlock);
PushResultId(loadVisitor.Evaluate(node));
}
@@ -541,7 +594,7 @@ namespace Nz
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().Append(SpirvOp::OpDot, typeId, resultId, vec1, vec2);
m_currentBlock->Append(SpirvOp::OpDot, typeId, resultId, vec1, vec2);
PushResultId(resultId);
break;
}
@@ -560,7 +613,7 @@ namespace Nz
UInt32 coordinatesId = EvaluateExpression(node.coordinates);
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
m_currentBlock->Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
PushResultId(resultId);
}
@@ -583,7 +636,7 @@ namespace Nz
if (node.componentCount > 1)
{
// Swizzling is implemented via SpirvOp::OpVectorShuffle using the same vector twice as operands
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
{
appender(m_writer.GetTypeId(targetType));
appender(resultId);
@@ -599,7 +652,7 @@ namespace Nz
// Extract a single component from the vector
assert(node.componentCount == 1);
m_writer.GetInstructions().Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) );
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) );
}
PushResultId(resultId);

View File

@@ -682,7 +682,7 @@ namespace Nz
using ConstantType = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<ConstantType, ConstantBool>)
constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, resultId);
constants.Append((arg.value) ? SpirvOp::OpConstantTrue : SpirvOp::OpConstantFalse, GetId({ Bool{} }), resultId);
else if constexpr (std::is_same_v<ConstantType, ConstantComposite>)
{
constants.AppendVariadic(SpirvOp::OpConstantComposite, [&](const auto& appender)

View File

@@ -4,7 +4,7 @@
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Debug.hpp>
@@ -26,7 +26,7 @@ namespace Nz
{
UInt32 resultId = m_writer.AllocateResultId();
m_writer.GetInstructions().Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId);
m_block.Append(SpirvOp::OpLoad, pointer.pointedTypeId, resultId, pointer.resultId);
return resultId;
},
@@ -53,7 +53,7 @@ namespace Nz
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
UInt32 typeId = m_writer.GetTypeId(node.exprType);
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(resultId);
@@ -70,7 +70,7 @@ namespace Nz
UInt32 resultId = m_writer.AllocateResultId();
UInt32 typeId = m_writer.GetTypeId(node.exprType);
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender)
m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender)
{
appender(typeId);
appender(resultId);

View File

@@ -3,7 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvExpressionStore.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <Nazara/Shader/Debug.hpp>
@@ -23,7 +23,7 @@ namespace Nz
{
[&](const Pointer& pointer)
{
m_writer.GetInstructions().Append(SpirvOp::OpStore, pointer.resultId, resultId);
m_block.Append(SpirvOp::OpStore, pointer.resultId, resultId);
},
[&](const LocalVar& value)
{
@@ -47,7 +47,7 @@ namespace Nz
UInt32 resultId = m_writer.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
m_writer.GetInstructions().AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(resultId);

View File

@@ -8,6 +8,7 @@
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <Nazara/Shader/SpirvData.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
@@ -395,23 +396,29 @@ namespace Nz
state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId);
state.instructions.Append(SpirvOp::OpLabel, AllocateResultId());
std::vector<SpirvBlock> blocks;
blocks.emplace_back(*this);
for (const auto& param : func.parameters)
{
UInt32 paramResultId = AllocateResultId();
funcData.paramsId.push_back(paramResultId);
state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
blocks.back().Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
}
SpirvAstVisitor visitor(*this);
SpirvAstVisitor visitor(*this, blocks);
visitor.Visit(functionStatements[funcIndex]);
if (func.returnType == ShaderNodes::BasicType::Void)
state.instructions.Append(SpirvOp::OpReturn);
blocks.back().Append(SpirvOp::OpReturn);
else
throw std::runtime_error("returning values from functions is not yet supported"); //< TODO
state.instructions.Append(SpirvOp::OpFunctionEnd);
blocks.back().Append(SpirvOp::OpFunctionEnd);
for (SpirvBlock& block : blocks)
state.instructions.Append(block);
}
assert(entryPointIndex != std::numeric_limits<std::size_t>::max());
@@ -552,11 +559,6 @@ namespace Nz
return it.value();
}
SpirvSection& SpirvWriter::GetInstructions()
{
return m_currentState->instructions;
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));