Shader: Fix SPIRV generation

This commit is contained in:
Jérôme Leclercq 2021-03-11 17:50:11 +01:00
parent fed7370e77
commit 48b93a9dea
8 changed files with 78 additions and 61 deletions

View File

@ -21,7 +21,7 @@ namespace Nz
class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept class NAZARA_SHADER_API SpirvAstVisitor : public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept
{ {
public: public:
inline SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks); inline SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks, ShaderAst::AstCache* cache);
SpirvAstVisitor(const SpirvAstVisitor&) = delete; SpirvAstVisitor(const SpirvAstVisitor&) = delete;
SpirvAstVisitor(SpirvAstVisitor&&) = delete; SpirvAstVisitor(SpirvAstVisitor&&) = delete;
~SpirvAstVisitor() = default; ~SpirvAstVisitor() = default;
@ -56,9 +56,10 @@ namespace Nz
void PushResultId(UInt32 value); void PushResultId(UInt32 value);
UInt32 PopResultId(); UInt32 PopResultId();
SpirvBlock* m_currentBlock; ShaderAst::AstCache* m_cache;
std::vector<SpirvBlock>& m_blocks; std::vector<SpirvBlock>& m_blocks;
std::vector<UInt32> m_resultIds; std::vector<UInt32> m_resultIds;
SpirvBlock* m_currentBlock;
SpirvWriter& m_writer; SpirvWriter& m_writer;
}; };
} }

View File

@ -7,7 +7,8 @@
namespace Nz namespace Nz
{ {
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks) : inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, std::vector<SpirvBlock>& blocks, ShaderAst::AstCache* cache) :
m_cache(cache),
m_blocks(blocks), m_blocks(blocks),
m_writer(writer) m_writer(writer)
{ {

View File

@ -172,6 +172,7 @@ namespace Nz
SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept; SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept;
static ConstantPtr BuildConstant(const ShaderConstantValue& value); static ConstantPtr BuildConstant(const ShaderConstantValue& value);
static TypePtr BuildFunctionType(const ShaderAst::ShaderExpressionType& retType, const std::vector<ShaderAst::ShaderExpressionType>& parameters);
static TypePtr BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass); static TypePtr BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass);
static TypePtr BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass); static TypePtr BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass);
static TypePtr BuildType(const ShaderAst::BasicType& type); static TypePtr BuildType(const ShaderAst::BasicType& type);

View File

@ -56,10 +56,8 @@ namespace Nz
void AppendHeader(); void AppendHeader();
SpirvConstantCache::Function BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters);
UInt32 GetConstantId(const ShaderConstantValue& value) const; UInt32 GetConstantId(const ShaderConstantValue& value) const;
UInt32 GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters); UInt32 GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode);
const ExtVar& GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const; const ExtVar& GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const;
const ExtVar& GetInputVariable(const std::string& name) const; const ExtVar& GetInputVariable(const std::string& name) const;
const ExtVar& GetOutputVariable(const std::string& name) const; const ExtVar& GetOutputVariable(const std::string& name) const;
@ -81,12 +79,13 @@ namespace Nz
std::optional<UInt32> ReadVariable(const ExtVar& var, OnlyCache); std::optional<UInt32> ReadVariable(const ExtVar& var, OnlyCache);
UInt32 RegisterConstant(const ShaderConstantValue& value); UInt32 RegisterConstant(const ShaderConstantValue& value);
UInt32 RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters); UInt32 RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode);
UInt32 RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass); UInt32 RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass);
UInt32 RegisterType(ShaderAst::ShaderExpressionType type); UInt32 RegisterType(ShaderAst::ShaderExpressionType type);
void WriteLocalVariable(std::string name, UInt32 resultId); void WriteLocalVariable(std::string name, UInt32 resultId);
static SpirvConstantCache::TypePtr BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode);
static void MergeSections(std::vector<UInt32>& output, const SpirvSection& from); static void MergeSections(std::vector<UInt32>& output, const SpirvSection& from);
struct Context struct Context

View File

@ -163,6 +163,8 @@ namespace Nz::ShaderAst
void ExpressionTypeVisitor::Visit(IdentifierExpression& node) void ExpressionTypeVisitor::Visit(IdentifierExpression& node)
{ {
assert(m_cache);
auto scopeIt = m_cache->scopeIdByNode.find(&node); auto scopeIt = m_cache->scopeIdByNode.find(&node);
if (scopeIt == m_cache->scopeIdByNode.end()) if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error"); throw std::runtime_error("internal error");

View File

@ -39,13 +39,13 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node) void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
{ {
ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node); ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsBasicType(resultExprType)); assert(IsBasicType(resultExprType));
ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left); ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache);
assert(IsBasicType(leftExprType)); assert(IsBasicType(leftExprType));
ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right); ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache);
assert(IsBasicType(rightExprType)); assert(IsBasicType(rightExprType));
ShaderAst::BasicType resultType = std::get<ShaderAst::BasicType>(resultExprType); ShaderAst::BasicType resultType = std::get<ShaderAst::BasicType>(resultExprType);
@ -582,7 +582,7 @@ namespace Nz
{ {
case ShaderAst::IntrinsicType::DotProduct: case ShaderAst::IntrinsicType::DotProduct:
{ {
const ShaderAst::ShaderExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); ShaderAst::ShaderExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache);
assert(IsBasicType(vecExprType)); assert(IsBasicType(vecExprType));
ShaderAst::BasicType vecType = std::get<ShaderAst::BasicType>(vecExprType); ShaderAst::BasicType vecType = std::get<ShaderAst::BasicType>(vecExprType);
@ -626,7 +626,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{ {
const ShaderAst::ShaderExpressionType& targetExprType = ShaderAst::GetExpressionType(node); ShaderAst::ShaderExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsBasicType(targetExprType)); assert(IsBasicType(targetExprType));
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType); ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);

View File

@ -570,18 +570,32 @@ namespace Nz
}, value)); }, value));
} }
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ShaderExpressionType& retType, const std::vector<ShaderAst::ShaderExpressionType>& parameters) -> TypePtr
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
for (const auto& parameterType : parameters)
parameterTypes.push_back(BuildPointerType(parameterType, SpirvStorageClass::Function));
return std::make_shared<Type>(Function{
BuildType(retType),
std::move(parameterTypes)
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr
{ {
return std::make_shared<Type>(SpirvConstantCache::Pointer{ return std::make_shared<Type>(Pointer{
SpirvConstantCache::BuildType(type), BuildType(type),
storageClass storageClass
}); });
} }
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
{ {
return std::make_shared<Type>(SpirvConstantCache::Pointer{ return std::make_shared<Type>(Pointer{
SpirvConstantCache::BuildType(type), BuildType(type),
storageClass storageClass
}); });
} }

View File

@ -31,6 +31,7 @@ namespace Nz
public: public:
using ExtInstList = std::unordered_set<std::string>; using ExtInstList = std::unordered_set<std::string>;
using LocalContainer = std::unordered_set<ShaderAst::ShaderExpressionType>; using LocalContainer = std::unordered_set<ShaderAst::ShaderExpressionType>;
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) : PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_cache(cache), m_cache(cache),
@ -79,9 +80,15 @@ namespace Nz
void Visit(ShaderAst::DeclareFunctionStatement& node) override void Visit(ShaderAst::DeclareFunctionStatement& node) override
{ {
m_constantCache.Register(*SpirvConstantCache::BuildType(node.returnType)); funcs.emplace_back(node);
std::vector<ShaderAst::ShaderExpressionType> parameterTypes;
for (auto& parameter : node.parameters) for (auto& parameter : node.parameters)
m_constantCache.Register(*SpirvConstantCache::BuildType(parameter.type)); parameterTypes.push_back(parameter.type);
m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes));
AstRecursiveVisitor::Visit(node);
} }
void Visit(ShaderAst::DeclareStructStatement& node) override void Visit(ShaderAst::DeclareStructStatement& node) override
@ -92,14 +99,14 @@ namespace Nz
void Visit(ShaderAst::DeclareVariableStatement& node) override void Visit(ShaderAst::DeclareVariableStatement& node) override
{ {
variableTypes.insert(node.varType); m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
AstRecursiveVisitor::Visit(node); AstRecursiveVisitor::Visit(node);
} }
void Visit(ShaderAst::IdentifierExpression& node) override void Visit(ShaderAst::IdentifierExpression& node) override
{ {
variableTypes.insert(GetExpressionType(node, m_cache)); m_constantCache.Register(*SpirvConstantCache::BuildType(GetExpressionType(node, m_cache)));
AstRecursiveVisitor::Visit(node); AstRecursiveVisitor::Visit(node);
} }
@ -122,7 +129,7 @@ namespace Nz
} }
ExtInstList extInsts; ExtInstList extInsts;
LocalContainer variableTypes; FunctionContainer funcs;
private: private:
ShaderAst::AstCache* m_cache; ShaderAst::AstCache* m_cache;
@ -210,8 +217,6 @@ namespace Nz
m_currentState = nullptr; m_currentState = nullptr;
}); });
std::vector<ShaderAst::StatementPtr> functionStatements;
ShaderAst::AstCloner cloner; ShaderAst::AstCloner cloner;
// Register all extended instruction sets // Register all extended instruction sets
@ -345,26 +350,22 @@ namespace Nz
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex); state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0); state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0);
} }
} }*/
for (const auto& func : shader.GetFunctions()) for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
{ {
auto& funcData = state.funcs.emplace_back(); auto& funcData = state.funcs.emplace_back();
funcData.id = AllocateResultId(); funcData.id = AllocateResultId();
funcData.typeId = GetFunctionTypeId(func.returnType, func.parameters); funcData.typeId = GetFunctionTypeId(func);
state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name); state.debugInfo.Append(SpirvOp::OpName, funcData.id, func.name);
} }
std::size_t entryPointIndex = std::numeric_limits<std::size_t>::max(); std::size_t funcIndex = 0;
for (std::size_t funcIndex = 0; funcIndex < shader.GetFunctionCount(); ++funcIndex) for (const ShaderAst::DeclareFunctionStatement& func : preVisitor.funcs)
{ {
const auto& func = shader.GetFunction(funcIndex); auto& funcData = state.funcs[funcIndex++];
if (func.name == "main")
entryPointIndex = funcIndex;
auto& funcData = state.funcs[funcIndex];
state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId); state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId);
@ -386,8 +387,9 @@ namespace Nz
state.parameterIds.emplace(param.name, std::move(parameterData)); state.parameterIds.emplace(param.name, std::move(parameterData));
} }
SpirvAstVisitor visitor(*this, state.functionBlocks); SpirvAstVisitor visitor(*this, state.functionBlocks, &m_context.cache);
visitor.Visit(functionStatements[funcIndex]); for (const auto& statement : func.statements)
statement->Visit(visitor);
if (!state.functionBlocks.back().IsTerminated()) if (!state.functionBlocks.back().IsTerminated())
{ {
@ -405,7 +407,7 @@ namespace Nz
AppendHeader(); AppendHeader();
if (entryPointIndex != std::numeric_limits<std::size_t>::max()) /*if (entryPointIndex != std::numeric_limits<std::size_t>::max())
{ {
SpvExecutionModel execModel; SpvExecutionModel execModel;
const auto& entryFuncData = shader.GetFunction(entryPointIndex); const auto& entryFuncData = shader.GetFunction(entryPointIndex);
@ -415,11 +417,11 @@ namespace Nz
switch (m_context.shader->GetStage()) switch (m_context.shader->GetStage())
{ {
case ShaderStageType::Fragment: case ShaderStageType::Fragment:
execModel = SpvExecutionModelFragment; execModel = SpirvExecutionModel::Fragment;
break; break;
case ShaderStageType::Vertex: case ShaderStageType::Vertex:
execModel = SpvExecutionModelVertex; execModel = SpirvExecutionModel::Vertex;
break; break;
default: default:
@ -445,15 +447,15 @@ namespace Nz
}); });
if (m_context.shader->GetStage() == ShaderStageType::Fragment) if (m_context.shader->GetStage() == ShaderStageType::Fragment)
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft); state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);
}*/ }*/
std::vector<UInt32> ret; std::vector<UInt32> ret;
/*MergeSections(ret, state.header); MergeSections(ret, state.header);
MergeSections(ret, state.debugInfo); MergeSections(ret, state.debugInfo);
MergeSections(ret, state.annotations); MergeSections(ret, state.annotations);
MergeSections(ret, state.constants); MergeSections(ret, state.constants);
MergeSections(ret, state.instructions);*/ MergeSections(ret, state.instructions);
return ret; return ret;
} }
@ -479,26 +481,12 @@ namespace Nz
m_currentState->header.AppendRaw(m_currentState->nextVarIndex); //< Bound (ID count) m_currentState->header.AppendRaw(m_currentState->nextVarIndex); //< Bound (ID count)
m_currentState->header.AppendRaw(0); //< Instruction schema (required to be 0 for now) m_currentState->header.AppendRaw(0); //< Instruction schema (required to be 0 for now)
m_currentState->header.Append(SpirvOp::OpCapability, SpvCapabilityShader); m_currentState->header.Append(SpirvOp::OpCapability, SpirvCapability::Shader);
for (const auto& [extInst, resultId] : m_currentState->extensionInstructions) for (const auto& [extInst, resultId] : m_currentState->extensionInstructions)
m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst); m_currentState->header.Append(SpirvOp::OpExtInstImport, resultId, extInst);
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450); m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450);
}
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
for (const auto& parameter : parameters)
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(parameter.type, SpirvStorageClass::Function));
return SpirvConstantCache::Function{
SpirvConstantCache::BuildType(retType),
std::move(parameterTypes)
};
} }
UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
@ -506,9 +494,9 @@ namespace Nz
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value)); return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
} }
UInt32 SpirvWriter::GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters) UInt32 SpirvWriter::GetFunctionTypeId(const ShaderAst::DeclareFunctionStatement& functionNode)
{ {
return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) }); return m_currentState->constantTypeCache.GetId({ *BuildFunctionType(functionNode) });
} }
auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar& auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
@ -644,9 +632,9 @@ namespace Nz
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value)); return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
} }
UInt32 SpirvWriter::RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters) UInt32 SpirvWriter::RegisterFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{ {
return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) }); return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
} }
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass) UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass)
@ -666,6 +654,17 @@ namespace Nz
m_currentState->varToResult.insert_or_assign(std::move(name), resultId); m_currentState->varToResult.insert_or_assign(std::move(name), resultId);
} }
SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{
std::vector<ShaderAst::ShaderExpressionType> parameterTypes;
parameterTypes.reserve(functionNode.parameters.size());
for (const auto& parameter : functionNode.parameters)
parameterTypes.push_back(parameter.type);
return SpirvConstantCache::BuildFunctionType(functionNode.returnType, parameterTypes);
}
void SpirvWriter::MergeSections(std::vector<UInt32>& output, const SpirvSection& from) void SpirvWriter::MergeSections(std::vector<UInt32>& output, const SpirvSection& from)
{ {
const std::vector<UInt32>& bytecode = from.GetBytecode(); const std::vector<UInt32>& bytecode = from.GetBytecode();