Rework shader AST (WIP)
This commit is contained in:
@@ -26,155 +26,131 @@ namespace Nz
|
||||
{
|
||||
namespace
|
||||
{
|
||||
class PreVisitor : public ShaderAstRecursiveVisitor, public ShaderVarVisitor
|
||||
class PreVisitor : public ShaderAst::AstRecursiveVisitor
|
||||
{
|
||||
public:
|
||||
using BuiltinContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::BuiltinVariable>>;
|
||||
using ExtInstList = std::unordered_set<std::string>;
|
||||
using LocalContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::LocalVariable>>;
|
||||
using ParameterContainer = std::unordered_set< std::shared_ptr<const ShaderNodes::ParameterVariable>>;
|
||||
using LocalContainer = std::unordered_set<ShaderAst::ShaderExpressionType>;
|
||||
|
||||
PreVisitor(const ShaderAst& shader, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
m_shader(shader),
|
||||
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
m_cache(cache),
|
||||
m_conditions(conditions),
|
||||
m_constantCache(constantCache)
|
||||
{
|
||||
}
|
||||
|
||||
using ShaderAstRecursiveVisitor::Visit;
|
||||
using ShaderVarVisitor::Visit;
|
||||
|
||||
void Visit(ShaderNodes::AccessMember& node) override
|
||||
void Visit(ShaderAst::AccessMemberExpression& node) override
|
||||
{
|
||||
for (std::size_t index : node.memberIndices)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));
|
||||
/*for (std::size_t index : node.memberIdentifiers)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override
|
||||
void Visit(ShaderAst::ConditionalExpression& node) override
|
||||
{
|
||||
std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
|
||||
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != ShaderAst::InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
|
||||
Visit(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);
|
||||
Visit(node.falsePath);*/
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override
|
||||
void Visit(ShaderAst::ConditionalStatement& node) override
|
||||
{
|
||||
std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
|
||||
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != ShaderAst::InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
|
||||
Visit(node.statement);
|
||||
Visit(node.statement);*/
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Constant& node) override
|
||||
void Visit(ShaderAst::ConstantExpression& node) override
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
|
||||
}, node.value);
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::DeclareVariable& node) override
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
Visit(node.variable);
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.returnType));
|
||||
for (auto& parameter : node.parameters)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(parameter.type));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Identifier& node) override
|
||||
void Visit(ShaderAst::DeclareStructStatement& node) override
|
||||
{
|
||||
Visit(node.var);
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
for (auto& field : node.description.members)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(field.type));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::IntrinsicCall& node) override
|
||||
void Visit(ShaderAst::DeclareVariableStatement& node) override
|
||||
{
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
variableTypes.insert(node.varType);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IdentifierExpression& node) override
|
||||
{
|
||||
variableTypes.insert(GetExpressionType(node, m_cache));
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IntrinsicExpression& node) override
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
// Require GLSL.std.450
|
||||
case ShaderNodes::IntrinsicType::CrossProduct:
|
||||
case ShaderAst::IntrinsicType::CrossProduct:
|
||||
extInsts.emplace("GLSL.std.450");
|
||||
break;
|
||||
|
||||
// Part of SPIR-V core
|
||||
case ShaderNodes::IntrinsicType::DotProduct:
|
||||
case ShaderAst::IntrinsicType::DotProduct:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::BuiltinVariable& var) override
|
||||
{
|
||||
builtinVars.insert(std::static_pointer_cast<const ShaderNodes::BuiltinVariable>(var.shared_from_this()));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::InputVariable& /*var*/) override
|
||||
{
|
||||
/* Handled by ShaderAst */
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::LocalVariable& var) override
|
||||
{
|
||||
localVars.insert(std::static_pointer_cast<const ShaderNodes::LocalVariable>(var.shared_from_this()));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::OutputVariable& /*var*/) override
|
||||
{
|
||||
/* Handled by ShaderAst */
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ParameterVariable& var) override
|
||||
{
|
||||
paramVars.insert(std::static_pointer_cast<const ShaderNodes::ParameterVariable>(var.shared_from_this()));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::UniformVariable& /*var*/) override
|
||||
{
|
||||
/* Handled by ShaderAst */
|
||||
}
|
||||
|
||||
BuiltinContainer builtinVars;
|
||||
ExtInstList extInsts;
|
||||
LocalContainer localVars;
|
||||
ParameterContainer paramVars;
|
||||
LocalContainer variableTypes;
|
||||
|
||||
private:
|
||||
const ShaderAst& m_shader;
|
||||
ShaderAst::AstCache* m_cache;
|
||||
const SpirvWriter::States& m_conditions;
|
||||
SpirvConstantCache& m_constantCache;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
constexpr ShaderNodes::BasicType GetBasicType()
|
||||
constexpr ShaderAst::BasicType GetBasicType()
|
||||
{
|
||||
if constexpr (std::is_same_v<T, bool>)
|
||||
return ShaderNodes::BasicType::Boolean;
|
||||
return ShaderAst::BasicType::Boolean;
|
||||
else if constexpr (std::is_same_v<T, float>)
|
||||
return(ShaderNodes::BasicType::Float1);
|
||||
return(ShaderAst::BasicType::Float1);
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
return(ShaderNodes::BasicType::Int1);
|
||||
return(ShaderAst::BasicType::Int1);
|
||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||
return(ShaderNodes::BasicType::Float2);
|
||||
return(ShaderAst::BasicType::Float2);
|
||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||
return(ShaderNodes::BasicType::Float3);
|
||||
return(ShaderAst::BasicType::Float3);
|
||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||
return(ShaderNodes::BasicType::Float4);
|
||||
return(ShaderAst::BasicType::Float4);
|
||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||
return(ShaderNodes::BasicType::Int2);
|
||||
return(ShaderAst::BasicType::Int2);
|
||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||
return(ShaderNodes::BasicType::Int3);
|
||||
return(ShaderAst::BasicType::Int3);
|
||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||
return(ShaderNodes::BasicType::Int4);
|
||||
return(ShaderAst::BasicType::Int4);
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "unhandled type");
|
||||
}
|
||||
@@ -198,7 +174,7 @@ namespace Nz
|
||||
tsl::ordered_map<std::string, ExtVar> parameterIds;
|
||||
tsl::ordered_map<std::string, ExtVar> uniformIds;
|
||||
std::unordered_map<std::string, UInt32> extensionInstructions;
|
||||
std::unordered_map<ShaderNodes::BuiltinEntry, ExtVar> builtinIds;
|
||||
std::unordered_map<ShaderAst::BuiltinEntry, ExtVar> builtinIds;
|
||||
std::unordered_map<std::string, UInt32> varToResult;
|
||||
std::vector<Func> funcs;
|
||||
std::vector<SpirvBlock> functionBlocks;
|
||||
@@ -219,13 +195,12 @@ namespace Nz
|
||||
{
|
||||
}
|
||||
|
||||
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst& shader, const States& conditions)
|
||||
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
|
||||
{
|
||||
std::string error;
|
||||
if (!ValidateShader(shader, &error))
|
||||
if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
m_context.shader = &shader;
|
||||
m_context.states = &conditions;
|
||||
|
||||
State state;
|
||||
@@ -235,23 +210,19 @@ namespace Nz
|
||||
m_currentState = nullptr;
|
||||
});
|
||||
|
||||
std::vector<ShaderNodes::StatementPtr> functionStatements;
|
||||
std::vector<ShaderAst::StatementPtr> functionStatements;
|
||||
|
||||
ShaderAstCloner cloner;
|
||||
|
||||
PreVisitor preVisitor(shader, conditions, state.constantTypeCache);
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
functionStatements.emplace_back(cloner.Clone(func.statement));
|
||||
preVisitor.Visit(func.statement);
|
||||
}
|
||||
ShaderAst::AstCloner cloner;
|
||||
|
||||
// Register all extended instruction sets
|
||||
PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache);
|
||||
shader->Visit(preVisitor);
|
||||
|
||||
for (const std::string& extInst : preVisitor.extInsts)
|
||||
state.extensionInstructions[extInst] = AllocateResultId();
|
||||
|
||||
// Register all types
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
/*for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
RegisterType(func.returnType);
|
||||
for (const auto& param : func.parameters)
|
||||
@@ -270,8 +241,8 @@ namespace Nz
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
RegisterFunctionType(func.returnType, func.parameters);
|
||||
|
||||
for (const auto& local : preVisitor.localVars)
|
||||
RegisterType(local->type);
|
||||
for (const auto& type : preVisitor.variableTypes)
|
||||
RegisterType(type);
|
||||
|
||||
for (const auto& builtin : preVisitor.builtinVars)
|
||||
RegisterType(builtin->type);
|
||||
@@ -283,7 +254,7 @@ namespace Nz
|
||||
SpirvBuiltIn builtinDecoration;
|
||||
switch (builtin->entry)
|
||||
{
|
||||
case ShaderNodes::BuiltinEntry::VertexPosition:
|
||||
case ShaderAst::BuiltinEntry::VertexPosition:
|
||||
variable.debugName = "builtin_VertexPosition";
|
||||
variable.storageClass = SpirvStorageClass::Output;
|
||||
|
||||
@@ -294,10 +265,10 @@ namespace Nz
|
||||
throw std::runtime_error("unexpected builtin type");
|
||||
}
|
||||
|
||||
const ShaderExpressionType& builtinExprType = builtin->type;
|
||||
const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type;
|
||||
assert(IsBasicType(builtinExprType));
|
||||
|
||||
ShaderNodes::BasicType builtinType = std::get<ShaderNodes::BasicType>(builtinExprType);
|
||||
ShaderAst::BasicType builtinType = std::get<ShaderAst::BasicType>(builtinExprType);
|
||||
|
||||
variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);
|
||||
|
||||
@@ -420,7 +391,7 @@ namespace Nz
|
||||
|
||||
if (!state.functionBlocks.back().IsTerminated())
|
||||
{
|
||||
assert(func.returnType == ShaderExpressionType(ShaderNodes::BasicType::Void));
|
||||
assert(func.returnType == ShaderAst::ShaderExpressionType(ShaderAst::BasicType::Void));
|
||||
state.functionBlocks.back().Append(SpirvOp::OpReturn);
|
||||
}
|
||||
|
||||
@@ -475,14 +446,14 @@ namespace Nz
|
||||
|
||||
if (m_context.shader->GetStage() == ShaderStageType::Fragment)
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft);
|
||||
}
|
||||
}*/
|
||||
|
||||
std::vector<UInt32> ret;
|
||||
MergeSections(ret, state.header);
|
||||
/*MergeSections(ret, state.header);
|
||||
MergeSections(ret, state.debugInfo);
|
||||
MergeSections(ret, state.annotations);
|
||||
MergeSections(ret, state.constants);
|
||||
MergeSections(ret, state.instructions);
|
||||
MergeSections(ret, state.instructions);*/
|
||||
|
||||
return ret;
|
||||
}
|
||||
@@ -516,16 +487,16 @@ namespace Nz
|
||||
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
|
||||
}
|
||||
|
||||
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
|
||||
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
|
||||
{
|
||||
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
|
||||
parameterTypes.reserve(parameters.size());
|
||||
|
||||
for (const auto& parameter : parameters)
|
||||
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(*m_context.shader, parameter.type, SpirvStorageClass::Function));
|
||||
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(parameter.type, SpirvStorageClass::Function));
|
||||
|
||||
return SpirvConstantCache::Function{
|
||||
SpirvConstantCache::BuildType(*m_context.shader, retType),
|
||||
SpirvConstantCache::BuildType(retType),
|
||||
std::move(parameterTypes)
|
||||
};
|
||||
}
|
||||
@@ -535,12 +506,12 @@ namespace Nz
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
|
||||
UInt32 SpirvWriter::GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) });
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar&
|
||||
auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->builtinIds.find(builtin);
|
||||
assert(it != m_currentState->builtinIds.end());
|
||||
@@ -572,14 +543,14 @@ namespace Nz
|
||||
return it.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const
|
||||
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetTypeId(const ShaderExpressionType& type) const
|
||||
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ShaderExpressionType& type) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(*m_context.shader, type));
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
|
||||
@@ -673,20 +644,20 @@ namespace Nz
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
|
||||
UInt32 SpirvWriter::RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) });
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass)
|
||||
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterType(ShaderExpressionType type)
|
||||
UInt32 SpirvWriter::RegisterType(ShaderAst::ShaderExpressionType type)
|
||||
{
|
||||
assert(m_currentState);
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(*m_context.shader, type));
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
|
||||
}
|
||||
|
||||
void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)
|
||||
|
||||
Reference in New Issue
Block a user