Shader: Add initial support for options (WIP)

This commit is contained in:
Jérôme Leclercq
2021-04-17 14:43:00 +02:00
parent 7da02c8708
commit 87ce2edc6e
45 changed files with 586 additions and 113 deletions

View File

@@ -85,9 +85,9 @@ namespace Nz
return std::make_shared<OpenGLShaderModule>(*this, shaderStages, shaderAst, states);
}
std::shared_ptr<ShaderModule> OpenGLDevice::InstantiateShaderModule(ShaderStageTypeFlags shaderStages, ShaderLanguage lang, const void* source, std::size_t sourceSize)
std::shared_ptr<ShaderModule> OpenGLDevice::InstantiateShaderModule(ShaderStageTypeFlags shaderStages, ShaderLanguage lang, const void* source, std::size_t sourceSize, const ShaderWriter::States& states)
{
return std::make_shared<OpenGLShaderModule>(*this, shaderStages, lang, source, sourceSize);
return std::make_shared<OpenGLShaderModule>(*this, shaderStages, lang, source, sourceSize, states);
}
std::shared_ptr<Texture> OpenGLDevice::InstantiateTexture(const TextureInfo& params)

View File

@@ -21,7 +21,7 @@ namespace Nz
Create(device, shaderStages, shaderAst, states);
}
OpenGLShaderModule::OpenGLShaderModule(OpenGLDevice& device, ShaderStageTypeFlags shaderStages, ShaderLanguage lang, const void* source, std::size_t sourceSize)
OpenGLShaderModule::OpenGLShaderModule(OpenGLDevice& device, ShaderStageTypeFlags shaderStages, ShaderLanguage lang, const void* source, std::size_t sourceSize, const ShaderWriter::States& states)
{
NazaraAssert(shaderStages != 0, "at least one shader stage must be specified");
@@ -65,7 +65,7 @@ namespace Nz
Nz::ShaderLang::Parser parser;
Nz::ShaderAst::StatementPtr shaderAst = parser.Parse(tokens);
Create(device, shaderStages, shaderAst, {});
Create(device, shaderStages, shaderAst, states);
break;
}

View File

@@ -11,7 +11,7 @@ namespace Nz
{
RenderDevice::~RenderDevice() = default;
std::shared_ptr<ShaderModule> RenderDevice::InstantiateShaderModule(ShaderStageTypeFlags shaderStages, ShaderLanguage lang, const std::filesystem::path& sourcePath)
std::shared_ptr<ShaderModule> RenderDevice::InstantiateShaderModule(ShaderStageTypeFlags shaderStages, ShaderLanguage lang, const std::filesystem::path& sourcePath, const ShaderWriter::States& states)
{
File file(sourcePath);
if (!file.Open(OpenMode_ReadOnly | OpenMode_Text))
@@ -29,6 +29,6 @@ namespace Nz
return {};
}
return InstantiateShaderModule(shaderStages, lang, source.data(), source.size());
return InstantiateShaderModule(shaderStages, lang, source.data(), source.size(), states);
}
}

View File

@@ -56,7 +56,7 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(ConditionalStatement& node)
{
auto clone = std::make_unique<ConditionalStatement>();
clone->conditionName = node.conditionName;
clone->optionIndex = node.optionIndex;
clone->statement = CloneStatement(node.statement);
return clone;
@@ -77,6 +77,7 @@ namespace Nz::ShaderAst
clone->entryStage = node.entryStage;
clone->funcIndex = node.funcIndex;
clone->name = node.name;
clone->optionName = node.optionName;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->varIndex = node.varIndex;
@@ -88,6 +89,17 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr AstCloner::Clone(DeclareOptionStatement& node)
{
auto clone = std::make_unique<DeclareOptionStatement>();
clone->initialValue = CloneExpression(node.initialValue);
clone->optIndex = node.optIndex;
clone->optName = node.optName;
clone->optType = node.optType;
return clone;
}
StatementPtr AstCloner::Clone(DeclareStructStatement& node)
{
auto clone = std::make_unique<DeclareStructStatement>();
@@ -212,7 +224,7 @@ namespace Nz::ShaderAst
ExpressionPtr AstCloner::Clone(ConditionalExpression& node)
{
auto clone = std::make_unique<ConditionalExpression>();
clone->conditionName = node.conditionName;
clone->optionIndex = node.optionIndex;
clone->falsePath = CloneExpression(node.falsePath);
clone->truePath = CloneExpression(node.truePath);
@@ -255,6 +267,18 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(SelectOptionExpression& node)
{
auto clone = std::make_unique<SelectOptionExpression>();
clone->optionName = node.optionName;
clone->falsePath = CloneExpression(node.falsePath);
clone->truePath = CloneExpression(node.truePath);
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(SwizzleExpression& node)
{
auto clone = std::make_unique<SwizzleExpression>();

View File

@@ -62,6 +62,12 @@ namespace Nz::ShaderAst
param->Visit(*this);
}
void AstRecursiveVisitor::Visit(SelectOptionExpression& node)
{
node.truePath->Visit(*this);
node.falsePath->Visit(*this);
}
void AstRecursiveVisitor::Visit(SwizzleExpression& node)
{
node.expression->Visit(*this);
@@ -100,6 +106,12 @@ namespace Nz::ShaderAst
statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareOptionStatement& node)
{
if (node.initialValue)
node.initialValue->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareStructStatement& /*node*/)
{
/* Nothing to do */

View File

@@ -74,7 +74,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(ConditionalExpression& node)
{
Value(node.conditionName);
SizeT(node.optionIndex);
Node(node.truePath);
Node(node.falsePath);
}
@@ -113,14 +113,6 @@ namespace Nz::ShaderAst
}
}
void AstSerializerBase::Serialize(DeclareVariableStatement& node)
{
OptVal(node.varIndex);
Value(node.varName);
Type(node.varType);
Node(node.initialExpression);
}
void AstSerializerBase::Serialize(IdentifierExpression& node)
{
Value(node.identifier);
@@ -134,6 +126,13 @@ namespace Nz::ShaderAst
Node(param);
}
void AstSerializerBase::Serialize(SelectOptionExpression& node)
{
Value(node.optionName);
Node(node.truePath);
Node(node.falsePath);
}
void AstSerializerBase::Serialize(SwizzleExpression& node)
{
SizeT(node.componentCount);
@@ -163,7 +162,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(ConditionalStatement& node)
{
Value(node.conditionName);
SizeT(node.optionIndex);
Node(node.statement);
}
@@ -186,6 +185,7 @@ namespace Nz::ShaderAst
Type(node.returnType);
OptEnum(node.entryStage);
OptVal(node.funcIndex);
Value(node.optionName);
OptVal(node.varIndex);
Container(node.parameters);
@@ -200,6 +200,14 @@ namespace Nz::ShaderAst
Node(statement);
}
void AstSerializerBase::Serialize(DeclareOptionStatement& node)
{
OptVal(node.optIndex);
Value(node.optName);
Type(node.optType);
Node(node.initialValue);
}
void AstSerializerBase::Serialize(DeclareStructStatement& node)
{
OptVal(node.structIndex);
@@ -216,6 +224,14 @@ namespace Nz::ShaderAst
OptVal(member.locationIndex);
}
}
void AstSerializerBase::Serialize(DeclareVariableStatement& node)
{
OptVal(node.varIndex);
Value(node.varName);
Type(node.varType);
Node(node.initialExpression);
}
void AstSerializerBase::Serialize(DiscardStatement& /*node*/)
{

View File

@@ -67,6 +67,20 @@ namespace Nz::ShaderAst
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(SelectOptionExpression& node)
{
node.truePath->Visit(*this);
ExpressionCategory trueExprCategory = m_expressionCategory;
node.falsePath->Visit(*this);
ExpressionCategory falseExprCategory = m_expressionCategory;
if (trueExprCategory == ExpressionCategory::RValue || falseExprCategory == ExpressionCategory::RValue)
m_expressionCategory = ExpressionCategory::RValue;
else
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
{
node.expression->Visit(*this);

View File

@@ -29,4 +29,35 @@ namespace Nz::ShaderAst
}
#include <Nazara/Shader/Ast/AstNodeList.hpp>
ExpressionType ConstantExpression::GetExpressionType() const
{
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return PrimitiveType::Float32;
else if constexpr (std::is_same_v<T, Int32>)
return PrimitiveType::Int32;
else if constexpr (std::is_same_v<T, UInt32>)
return PrimitiveType::UInt32;
else if constexpr (std::is_same_v<T, Vector2f>)
return VectorType{ 2, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector3f>)
return VectorType{ 3, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector4f>)
return VectorType{ 4, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector2i32>)
return VectorType{ 2, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector3i32>)
return VectorType{ 3, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector4i32>)
return VectorType{ 4, PrimitiveType::Int32 };
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, value);
}
}

View File

@@ -5,6 +5,7 @@
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Core/StackArray.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/Ast/AstUtils.hpp>
#include <stdexcept>
#include <unordered_set>
@@ -28,16 +29,18 @@ namespace Nz::ShaderAst
struct SanitizeVisitor::Context
{
Options options;
std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<unsigned int> usedBindingIndexes;
};
StatementPtr SanitizeVisitor::Sanitize(StatementPtr& nodePtr, std::string* error)
StatementPtr SanitizeVisitor::Sanitize(StatementPtr& nodePtr, const Options& options, std::string* error)
{
StatementPtr clone;
Context currentContext;
currentContext.options = options;
m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; });
@@ -483,6 +486,33 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(SelectOptionExpression& node)
{
MandatoryExpr(node.truePath);
MandatoryExpr(node.falsePath);
auto condExpr = std::make_unique<ConditionalExpression>();
condExpr->truePath = CloneExpression(node.truePath);
condExpr->falsePath = CloneExpression(node.falsePath);
const Identifier* identifier = FindIdentifier(node.optionName);
if (!identifier)
throw AstError{ "unknown option " + node.optionName };
if (!std::holds_alternative<Option>(identifier->value))
throw AstError{ "expected option identifier" };
condExpr->optionIndex = std::get<Option>(identifier->value).optionIndex;
const ExpressionType& leftExprType = GetExpressionType(*condExpr->truePath);
if (leftExprType != GetExpressionType(*condExpr->falsePath))
throw AstError{ "true path type must match false path type" };
condExpr->cachedExpressionType = leftExprType;
return condExpr;
}
ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node)
{
if (node.componentCount > 4)
@@ -585,9 +615,13 @@ namespace Nz::ShaderAst
{
extVar.type = ResolveType(extVar.type);
ExpressionType varType = extVar.type;
ExpressionType varType;
if (IsUniformType(extVar.type))
varType = std::get<StructType>(std::get<UniformType>(varType).containedType);
varType = std::get<StructType>(std::get<UniformType>(extVar.type).containedType);
else if (IsSamplerType(extVar.type))
varType = extVar.type;
else
throw AstError{ "External variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
std::size_t varIndex = RegisterVariable(extVar.name, std::move(varType));
if (!clone->varIndex)
@@ -616,6 +650,7 @@ namespace Nz::ShaderAst
clone->entryStage = node.entryStage;
clone->name = node.name;
clone->funcIndex = m_nextFuncIndex++;
clone->optionName = node.optionName;
clone->parameters = node.parameters;
clone->returnType = ResolveType(node.returnType);
@@ -635,6 +670,36 @@ namespace Nz::ShaderAst
}
PopScope();
if (!clone->optionName.empty())
{
const Identifier* identifier = FindIdentifier(node.optionName);
if (!identifier)
throw AstError{ "unknown option " + node.optionName };
if (!std::holds_alternative<Option>(identifier->value))
throw AstError{ "expected option identifier" };
std::size_t optionIndex = std::get<Option>(identifier->value).optionIndex;
return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone));
}
return clone;
}
StatementPtr SanitizeVisitor::Clone(DeclareOptionStatement& node)
{
auto clone = static_unique_pointer_cast<DeclareOptionStatement>(AstCloner::Clone(node));
clone->optType = ResolveType(clone->optType);
if (clone->initialValue && clone->optType != GetExpressionType(*clone->initialValue))
throw AstError{ "option " + clone->optName + " initial expression must be of the same type than the option" };
clone->optIndex = RegisterOption(clone->optName, clone->optType);
if (m_context->options.removeOptionDeclaration)
return ShaderBuilder::NoOp();
return clone;
}

View File

@@ -35,6 +35,12 @@ namespace Nz
{
using AstRecursiveVisitor::Visit;
void Visit(ShaderAst::ConditionalStatement& node) override
{
if (TestBit<UInt64>(enabledOptions, node.optionIndex))
node.statement->Visit(*this);
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
// Dismiss function if it's an entry point of another type than the one selected
@@ -46,6 +52,7 @@ namespace Nz
if (stage != *selectedStage)
return;
assert(!entryPoint);
entryPoint = &node;
}
}
@@ -58,6 +65,7 @@ namespace Nz
std::optional<ShaderStageType> selectedStage;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
UInt64 enabledOptions = 0;
};
struct Builtin
@@ -88,13 +96,15 @@ namespace Nz
std::unordered_map<std::size_t, std::string> variableNames;
std::vector<InOutField> inputFields;
std::vector<InOutField> outputFields;
UInt64 enabledOptions = 0;
bool isInEntryPoint = false;
unsigned int indentLevel = 0;
};
std::string GlslWriter::Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::StatementPtr& shader, const States& conditions)
std::string GlslWriter::Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::StatementPtr& shader, const States& states)
{
State state;
state.enabledOptions = states.enabledOptions;
state.stage = shaderStage;
m_currentState = &state;
@@ -106,6 +116,7 @@ namespace Nz
ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader);
PreVisitor previsitor;
previsitor.enabledOptions = states.enabledOptions;
previsitor.selectedStage = shaderStage;
sanitizedAst->Visit(previsitor);
@@ -574,22 +585,16 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::ConditionalExpression& node)
{
/*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex))
if (TestBit<Nz::UInt64>(m_currentState->enabledOptions, node.optionIndex))
Visit(node.truePath);
else
Visit(node.falsePath);*/
Visit(node.falsePath);
}
void GlslWriter::Visit(ShaderAst::ConditionalStatement& node)
{
/*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex))
Visit(node.statement);*/
if (TestBit<Nz::UInt64>(m_currentState->enabledOptions, node.optionIndex))
node.statement->Visit(*this);
}
void GlslWriter::Visit(ShaderAst::ConstantExpression& node)

View File

@@ -40,13 +40,15 @@ namespace Nz::ShaderLang
ForceCLocale forceCLocale;
std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "external", TokenType::External },
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "let", TokenType::Let },
{ "return", TokenType::Return },
{ "struct", TokenType::Struct },
{ "true", TokenType::BoolTrue }
{ "external", TokenType::External },
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "let", TokenType::Let },
{ "option", TokenType::Option },
{ "return", TokenType::Return },
{ "select_opt", TokenType::SelectOpt },
{ "struct", TokenType::Struct },
{ "true", TokenType::BoolTrue }
};
std::size_t currentPos = 0;

View File

@@ -28,7 +28,8 @@ namespace Nz::ShaderLang
{ "builtin", ShaderAst::AttributeType::Builtin },
{ "entry", ShaderAst::AttributeType::Entry },
{ "layout", ShaderAst::AttributeType::Layout },
{ "location", ShaderAst::AttributeType::Location },
{ "location", ShaderAst::AttributeType::Location },
{ "opt", ShaderAst::AttributeType::Option },
};
std::unordered_map<std::string, ShaderStageType> s_entryPoints = {
@@ -90,6 +91,13 @@ namespace Nz::ShaderLang
assert(attributes.empty());
attributes = ParseAttributes();
break;
case TokenType::Option:
if (!attributes.empty())
throw UnexpectedToken{};
context.root->statements.push_back(ParseOptionDeclaration());
break;
case TokenType::FunctionDeclaration:
context.root->statements.push_back(ParseFunctionDeclaration(std::move(attributes)));
@@ -450,14 +458,15 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingCurlyBracket);
std::optional<ShaderStageType> entryPoint;
auto func = ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
for (const auto& [attributeType, arg] : attributes)
{
switch (attributeType)
{
case ShaderAst::AttributeType::Entry:
{
if (entryPoint)
if (func->entryStage)
throw AttributeError{ "attribute entry must be present once" };
if (!std::holds_alternative<std::string>(arg))
@@ -469,7 +478,19 @@ namespace Nz::ShaderLang
if (it == s_entryPoints.end())
throw AttributeError{ ("invalid parameter " + argStr + " for entry attribute").c_str() };
entryPoint = it->second;
func->entryStage = it->second;
break;
}
case ShaderAst::AttributeType::Option:
{
if (!func->optionName.empty())
throw AttributeError{ "attribute option must be present once" };
if (!std::holds_alternative<std::string>(arg))
throw AttributeError{ "attribute option requires a string parameter" };
func->optionName = std::get<std::string>(arg);
break;
}
@@ -477,8 +498,8 @@ namespace Nz::ShaderLang
throw AttributeError{ "unhandled attribute for function" };
}
}
return ShaderBuilder::DeclareFunction(entryPoint, std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
return func;
}
ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter()
@@ -492,6 +513,29 @@ namespace Nz::ShaderLang
return { parameterName, parameterType };
}
ShaderAst::StatementPtr Parser::ParseOptionDeclaration()
{
Expect(Advance(), TokenType::Option);
std::string optionName = ParseIdentifierAsName();
Expect(Advance(), TokenType::Colon);
ShaderAst::ExpressionType optionType = ParseType();
ShaderAst::ExpressionPtr initialValue;
if (Peek().type == TokenType::Assign)
{
Consume();
initialValue = ParseExpression();
}
Expect(Advance(), TokenType::Semicolon);
return ShaderBuilder::DeclareOption(std::move(optionName), std::move(optionType), std::move(initialValue));
}
ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes)
{
Expect(Advance(), TokenType::Struct);
@@ -871,12 +915,35 @@ namespace Nz::ShaderLang
break;
case TokenType::OpenParenthesis:
return ParseParenthesisExpression();
case TokenType::SelectOpt:
return ParseSelectOptExpression();
default:
throw UnexpectedToken{};
}
}
ShaderAst::ExpressionPtr Parser::ParseSelectOptExpression()
{
Expect(Advance(), TokenType::SelectOpt);
Expect(Advance(), TokenType::OpenParenthesis);
std::string optionName = ParseIdentifierAsName();
Expect(Advance(), TokenType::Comma);
ShaderAst::ExpressionPtr trueExpr = ParseExpression();
Expect(Advance(), TokenType::Comma);
ShaderAst::ExpressionPtr falseExpr = ParseExpression();
Expect(Advance(), TokenType::ClosingParenthesis);
return ShaderBuilder::SelectOption(std::move(optionName), std::move(trueExpr), std::move(falseExpr));
}
ShaderAst::AttributeType Parser::ParseIdentifierAsAttributeType()
{

View File

@@ -502,7 +502,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::ConditionalExpression& node)
{
if (m_writer.IsConditionEnabled(node.conditionName))
if (m_writer.IsOptionEnabled(node.optionIndex))
node.truePath->Visit(*this);
else
node.falsePath->Visit(*this);
@@ -510,7 +510,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::ConditionalStatement& node)
{
if (m_writer.IsConditionEnabled(node.conditionName))
if (m_writer.IsOptionEnabled(node.optionIndex))
node.statement->Visit(*this);
}

View File

@@ -56,7 +56,7 @@ namespace Nz
using StructContainer = std::vector<ShaderAst::StructDescription>;
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector<SpirvAstVisitor::FuncData>& funcs) :
m_conditions(conditions),
m_states(conditions),
m_constantCache(constantCache),
m_externalBlockIndex(0),
m_funcs(funcs)
@@ -80,24 +80,18 @@ namespace Nz
void Visit(ShaderAst::ConditionalExpression& node) override
{
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
Visit(node.truePath);
if (TestBit<Nz::UInt64>(m_states.enabledOptions, node.optionIndex))
node.truePath->Visit(*this);
else
Visit(node.falsePath);*/
node.falsePath->Visit(*this);
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::ConditionalStatement& node) override
{
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
Visit(node.statement);*/
if (TestBit<Nz::UInt64>(m_states.enabledOptions, node.optionIndex))
node.statement->Visit(*this);
}
void Visit(ShaderAst::ConstantExpression& node) override
@@ -347,7 +341,7 @@ namespace Nz
StructContainer declaredStructs;
private:
const SpirvWriter::States& m_conditions;
const SpirvWriter::States& m_states;
SpirvConstantCache& m_constantCache;
std::optional<std::size_t> m_funcIndex;
std::size_t m_externalBlockIndex;
@@ -390,11 +384,11 @@ namespace Nz
{
}
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& states)
{
ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader);
m_context.states = &conditions;
m_context.states = &states;
State state;
m_currentState = &state;
@@ -404,7 +398,7 @@ namespace Nz
});
// Register all extended instruction sets
PreVisitor preVisitor(conditions, state.constantTypeCache, state.funcs);
PreVisitor preVisitor(states, state.constantTypeCache, state.funcs);
sanitizedAst->Visit(preVisitor);
m_currentState->preVisitor = &preVisitor;
@@ -559,6 +553,11 @@ namespace Nz
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildType(type));
}
bool SpirvWriter::IsOptionEnabled(std::size_t optionIndex) const
{
return TestBit<Nz::UInt64>(m_context.states->enabledOptions, optionIndex);
}
UInt32 SpirvWriter::RegisterConstant(const ShaderAst::ConstantValue& value)
{
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildConstant(value));

View File

@@ -60,10 +60,10 @@ namespace Nz
return stage;
}
std::shared_ptr<ShaderModule> VulkanDevice::InstantiateShaderModule(ShaderStageTypeFlags stages, ShaderLanguage lang, const void* source, std::size_t sourceSize)
std::shared_ptr<ShaderModule> VulkanDevice::InstantiateShaderModule(ShaderStageTypeFlags stages, ShaderLanguage lang, const void* source, std::size_t sourceSize, const ShaderWriter::States& states)
{
auto stage = std::make_shared<VulkanShaderModule>();
if (!stage->Create(*this, stages, lang, source, sourceSize))
if (!stage->Create(*this, stages, lang, source, sourceSize, states))
throw std::runtime_error("failed to instanciate vulkan shader module");
return stage;

View File

@@ -65,10 +65,10 @@ namespace Nz
writer.SetEnv(env);
std::vector<UInt32> code = writer.Generate(shaderAst, states);
return Create(device, shaderStages, ShaderLanguage::SpirV, code.data(), code.size() * sizeof(UInt32));
return Create(device, shaderStages, ShaderLanguage::SpirV, code.data(), code.size() * sizeof(UInt32), {});
}
bool VulkanShaderModule::Create(Vk::Device& device, ShaderStageTypeFlags shaderStages, ShaderLanguage lang, const void* source, std::size_t sourceSize)
bool VulkanShaderModule::Create(Vk::Device& device, ShaderStageTypeFlags shaderStages, ShaderLanguage lang, const void* source, std::size_t sourceSize, const ShaderWriter::States& states)
{
switch (lang)
{
@@ -89,7 +89,7 @@ namespace Nz
Nz::ShaderLang::Parser parser;
Nz::ShaderAst::StatementPtr shaderAst = parser.Parse(tokens);
return Create(device, shaderStages, shaderAst, {});
return Create(device, shaderStages, shaderAst, states);
}
case ShaderLanguage::SpirV:

View File

@@ -46,7 +46,7 @@ Nz::ShaderAst::NodePtr ConditionalExpression::BuildNode(Nz::ShaderAst::Expressio
const ShaderGraph& graph = GetGraph();
const auto& conditionEntry = graph.GetCondition(*m_currentConditionIndex);
return Nz::ShaderBuilder::ConditionalExpression(conditionEntry.name, std::move(expressions[0]), std::move(expressions[1]));
return Nz::ShaderBuilder::SelectOption(conditionEntry.name, std::move(expressions[0]), std::move(expressions[1]));
}
QString ConditionalExpression::caption() const

View File

@@ -454,6 +454,10 @@ Nz::ShaderAst::StatementPtr ShaderGraph::ToAst() const
{
std::vector<Nz::ShaderAst::StatementPtr> statements;
// Declare all options
for (const auto& condition : m_conditions)
statements.push_back(Nz::ShaderBuilder::DeclareOption(condition.name, Nz::ShaderAst::PrimitiveType::Boolean));
// Declare all structures
for (const auto& structInfo : m_structs)
{