Handle shader options of any type

This commit is contained in:
Jérôme Leclercq
2021-09-03 19:33:41 +02:00
parent 2f9e495739
commit 02a12d9328
38 changed files with 236 additions and 1118 deletions

View File

@@ -116,7 +116,7 @@ namespace Nz::ShaderAst
StatementPtr AstCloner::Clone(DeclareOptionStatement& node)
{
auto clone = std::make_unique<DeclareOptionStatement>();
clone->initialValue = CloneExpression(node.initialValue);
clone->defaultValue = CloneExpression(node.defaultValue);
clone->optIndex = node.optIndex;
clone->optName = node.optName;
clone->optType = node.optType;

View File

@@ -697,7 +697,9 @@ namespace Nz::ShaderAst
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
if constexpr (std::is_same_v<T, NoValue>)
throw std::runtime_error("invalid type (value expected)");
else if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
constantValues.push_back(arg);
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i32>)
{
@@ -815,9 +817,6 @@ namespace Nz::ShaderAst
ExpressionPtr AstOptimizer::Clone(ConditionalExpression& node)
{
if (!m_options.enabledOptions)
return AstCloner::Clone(node);
auto cond = CloneExpression(node.condition);
if (cond->GetType() != NodeType::ConstantValueExpression)
throw std::runtime_error("conditional expression condition must be a constant expression");
@@ -884,9 +883,6 @@ namespace Nz::ShaderAst
StatementPtr AstOptimizer::Clone(ConditionalStatement& node)
{
if (!m_options.enabledOptions)
return AstCloner::Clone(node);
auto cond = CloneExpression(node.condition);
if (cond->GetType() != NodeType::ConstantValueExpression)
throw std::runtime_error("conditional expression condition must be a constant expression");

View File

@@ -134,8 +134,8 @@ namespace Nz::ShaderAst
void AstRecursiveVisitor::Visit(DeclareOptionStatement& node)
{
if (node.initialValue)
node.initialValue->Visit(*this);
if (node.defaultValue)
node.defaultValue->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareStructStatement& /*node*/)

View File

@@ -140,19 +140,20 @@ namespace Nz::ShaderAst
Value(value);
};
static_assert(std::variant_size_v<decltype(node.value)> == 10);
static_assert(std::variant_size_v<decltype(node.value)> == 11);
switch (typeIndex)
{
case 0: SerializeValue(bool()); break;
case 1: SerializeValue(float()); break;
case 2: SerializeValue(Int32()); break;
case 3: SerializeValue(UInt32()); break;
case 4: SerializeValue(Vector2f()); break;
case 5: SerializeValue(Vector3f()); break;
case 6: SerializeValue(Vector4f()); break;
case 7: SerializeValue(Vector2i32()); break;
case 8: SerializeValue(Vector3i32()); break;
case 9: SerializeValue(Vector4i32()); break;
case 0: break;
case 1: SerializeValue(bool()); break;
case 2: SerializeValue(float()); break;
case 3: SerializeValue(Int32()); break;
case 4: SerializeValue(UInt32()); break;
case 5: SerializeValue(Vector2f()); break;
case 6: SerializeValue(Vector3f()); break;
case 7: SerializeValue(Vector4f()); break;
case 8: SerializeValue(Vector2i32()); break;
case 9: SerializeValue(Vector3i32()); break;
case 10: SerializeValue(Vector4i32()); break;
default: throw std::runtime_error("unexpected data type");
}
}
@@ -261,7 +262,7 @@ namespace Nz::ShaderAst
OptVal(node.optIndex);
Value(node.optName);
Type(node.optType);
Node(node.initialValue);
Node(node.defaultValue);
}
void AstSerializerBase::Serialize(DeclareStructStatement& node)

View File

@@ -13,7 +13,9 @@ namespace Nz::ShaderAst
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
if constexpr (std::is_same_v<T, NoValue>)
return NoType{};
else if constexpr (std::is_same_v<T, bool>)
return PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return PrimitiveType::Float32;

View File

@@ -583,6 +583,9 @@ namespace Nz::ShaderAst
ExpressionPtr SanitizeVisitor::Clone(ConstantValueExpression& node)
{
if (std::holds_alternative<NoValue>(node.value))
throw std::runtime_error("expected a value");
auto clone = static_unique_pointer_cast<ConstantValueExpression>(AstCloner::Clone(node));
clone->cachedExpressionType = GetExpressionType(clone->value);
@@ -980,12 +983,17 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<DeclareOptionStatement>(AstCloner::Clone(node));
clone->optType = ResolveType(clone->optType);
if (clone->initialValue && clone->optType != GetExpressionType(*clone->initialValue))
throw AstError{ "option " + clone->optName + " initial expression must be of the same type than the option" };
if (clone->defaultValue && clone->optType != GetExpressionType(*clone->defaultValue))
throw AstError{ "option " + clone->optName + " default expression must be of the same type than the option" };
std::size_t optionIndex = m_context->nextOptionIndex++;
clone->optIndex = RegisterConstant(clone->optName, TestBit(m_context->options.enabledOptions, optionIndex));
if (auto optionValueIt = m_context->options.optionValues.find(optionIndex); optionValueIt != m_context->options.optionValues.end())
clone->optIndex = RegisterConstant(clone->optName, optionValueIt->second);
else if (clone->defaultValue)
clone->optIndex = RegisterConstant(clone->optName, ComputeConstantValue(*clone->defaultValue));
else
throw AstError{ "missing option " + clone->optName + " value (has no default value)" };
if (m_context->options.removeOptionDeclaration)
return ShaderBuilder::NoOp();
@@ -1186,8 +1194,6 @@ namespace Nz::ShaderAst
return m_context->constantValues[constantId];
};
optimizerOptions.enabledOptions = m_context->options.enabledOptions;
// Run optimizer on constant value to hopefully retrieve a single constant value
return static_unique_pointer_cast<T>(ShaderAst::Optimize(node, optimizerOptions));
}

View File

@@ -107,7 +107,6 @@ namespace Nz
std::optional<ShaderStageType> selectedStage;
std::unordered_map<std::size_t, FunctionData> functions;
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
UInt64 enabledOptions = 0;
};
struct Builtin
@@ -139,6 +138,7 @@ namespace Nz
std::optional<ShaderStageType> stage;
std::stringstream stream;
std::unordered_map<std::size_t, ShaderAst::ConstantValue> optionValues;
std::unordered_map<std::size_t, ShaderAst::StructDescription*> structs;
std::unordered_map<std::size_t, std::string> variableNames;
std::vector<InOutField> inputFields;
@@ -147,7 +147,6 @@ namespace Nz
const GlslWriter::BindingMapping& bindingMapping;
PreVisitor previsitor;
const States* states = nullptr;
UInt64 enabledOptions = 0;
bool isInEntryPoint = false;
unsigned int indentLevel = 0;
};
@@ -155,7 +154,7 @@ namespace Nz
std::string GlslWriter::Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::Statement& shader, const BindingMapping& bindingMapping, const States& states)
{
State state(bindingMapping);
state.enabledOptions = states.enabledOptions;
state.optionValues = states.optionValues;
state.stage = shaderStage;
m_currentState = &state;
@@ -168,7 +167,7 @@ namespace Nz
ShaderAst::Statement* targetAst;
if (!states.sanitized)
{
sanitizedAst = Sanitize(shader, states.enabledOptions);
sanitizedAst = Sanitize(shader, states.optionValues);
targetAst = sanitizedAst.get();
}
else
@@ -182,7 +181,6 @@ namespace Nz
targetAst = optimizedAst.get();
}
state.previsitor.enabledOptions = states.enabledOptions;
state.previsitor.selectedStage = shaderStage;
targetAst->Visit(state.previsitor);
@@ -203,11 +201,11 @@ namespace Nz
return s_flipYUniformName;
}
ShaderAst::StatementPtr GlslWriter::Sanitize(ShaderAst::Statement& ast, UInt64 enabledOptions, std::string* error)
ShaderAst::StatementPtr GlslWriter::Sanitize(ShaderAst::Statement& ast, std::unordered_map<std::size_t, ShaderAst::ConstantValue> optionValues, std::string* error)
{
// Always sanitize for reserved identifiers
ShaderAst::SanitizeVisitor::Options options;
options.enabledOptions = enabledOptions;
options.optionValues = std::move(optionValues);
options.makeVariableNameUnique = true;
options.reservedIdentifiers = {
// All reserved GLSL keywords as of GLSL ES 3.2
@@ -837,7 +835,9 @@ namespace Nz
if constexpr (std::is_same_v<T, Vector2i32> || std::is_same_v<T, Vector3i32> || std::is_same_v<T, Vector4i32>)
Append("i"); //< for ivec
if constexpr (std::is_same_v<T, bool>)
if constexpr (std::is_same_v<T, ShaderAst::NoValue>)
throw std::runtime_error("invalid type (value expected)");
else if constexpr (std::is_same_v<T, bool>)
Append((arg) ? "true" : "false");
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
Append(std::to_string(arg));

View File

@@ -643,7 +643,9 @@ namespace Nz
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
if constexpr (std::is_same_v<T, ShaderAst::NoValue>)
throw std::runtime_error("invalid type (value expected)");
else if constexpr (std::is_same_v<T, bool>)
Append((arg) ? "true" : "false");
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
Append(std::to_string(arg));
@@ -733,10 +735,10 @@ namespace Nz
RegisterConstant(*node.optIndex, node.optName);
Append("option ", node.optName, ": ", node.optType);
if (node.initialValue)
if (node.defaultValue)
{
Append(" = ");
node.initialValue->Visit(*this);
node.defaultValue->Visit(*this);
}
Append(";");

View File

@@ -418,7 +418,9 @@ namespace Nz
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
if constexpr (std::is_same_v<T, ShaderAst::NoValue>)
throw std::runtime_error("invalid type (value expected)");
else if constexpr (std::is_same_v<T, bool>)
return ConstantBool{ arg };
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
return ConstantScalar{ arg };

View File

@@ -463,7 +463,7 @@ namespace Nz
if (!states.sanitized)
{
ShaderAst::SanitizeVisitor::Options options;
options.enabledOptions = states.enabledOptions;
options.optionValues = states.optionValues;
sanitizedAst = ShaderAst::Sanitize(shader, options);
targetAst = sanitizedAst.get();
@@ -653,11 +653,6 @@ namespace Nz
return m_currentState->constantTypeCache.GetId(*m_currentState->constantTypeCache.BuildType(type));
}
bool SpirvWriter::IsOptionEnabled(std::size_t optionIndex) const
{
return TestBit<Nz::UInt64>(m_context.states->enabledOptions, optionIndex);
}
UInt32 SpirvWriter::RegisterConstant(const ShaderAst::ConstantValue& value)
{
return m_currentState->constantTypeCache.Register(*m_currentState->constantTypeCache.BuildConstant(value));