Shader: Implement const if and const values
This commit is contained in:
parent
d679eccb43
commit
1f6937ab1b
|
|
@ -57,6 +57,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
virtual StatementPtr Clone(BranchStatement& node);
|
||||
virtual StatementPtr Clone(ConditionalStatement& node);
|
||||
virtual StatementPtr Clone(DeclareConstStatement& node);
|
||||
virtual StatementPtr Clone(DeclareExternalStatement& node);
|
||||
virtual StatementPtr Clone(DeclareFunctionStatement& node);
|
||||
virtual StatementPtr Clone(DeclareOptionStatement& node);
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ NAZARA_SHADERAST_EXPRESSION(VariableExpression)
|
|||
NAZARA_SHADERAST_EXPRESSION(UnaryExpression)
|
||||
NAZARA_SHADERAST_STATEMENT(BranchStatement)
|
||||
NAZARA_SHADERAST_STATEMENT(ConditionalStatement)
|
||||
NAZARA_SHADERAST_STATEMENT(DeclareConstStatement)
|
||||
NAZARA_SHADERAST_STATEMENT(DeclareExternalStatement)
|
||||
NAZARA_SHADERAST_STATEMENT(DeclareFunctionStatement)
|
||||
NAZARA_SHADERAST_STATEMENT(DeclareOptionStatement)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
void Visit(BranchStatement& node) override;
|
||||
void Visit(ConditionalStatement& node) override;
|
||||
void Visit(DeclareConstStatement& node) override;
|
||||
void Visit(DeclareExternalStatement& node) override;
|
||||
void Visit(DeclareFunctionStatement& node) override;
|
||||
void Visit(DeclareOptionStatement& node) override;
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
void Serialize(BranchStatement& node);
|
||||
void Serialize(ConditionalStatement& node);
|
||||
void Serialize(DeclareConstStatement& node);
|
||||
void Serialize(DeclareExternalStatement& node);
|
||||
void Serialize(DeclareFunctionStatement& node);
|
||||
void Serialize(DeclareOptionStatement& node);
|
||||
|
|
|
|||
|
|
@ -242,6 +242,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
std::vector<ConditionalStatement> condStatements;
|
||||
StatementPtr elseStatement;
|
||||
bool isConst = false;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API ConditionalStatement : Statement
|
||||
|
|
@ -253,6 +254,17 @@ namespace Nz::ShaderAst
|
|||
StatementPtr statement;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API DeclareConstStatement : Statement
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
void Visit(AstStatementVisitor& visitor) override;
|
||||
|
||||
std::optional<std::size_t> constIndex;
|
||||
std::string name;
|
||||
ExpressionPtr expression;
|
||||
ExpressionType type;
|
||||
};
|
||||
|
||||
struct NAZARA_SHADER_API DeclareExternalStatement : Statement
|
||||
{
|
||||
NodeType GetType() const override;
|
||||
|
|
|
|||
|
|
@ -55,14 +55,17 @@ namespace Nz::ShaderAst
|
|||
ExpressionPtr Clone(CastExpression& node) override;
|
||||
ExpressionPtr Clone(ConditionalExpression& node) override;
|
||||
ExpressionPtr Clone(ConstantExpression& node) override;
|
||||
ExpressionPtr Clone(ConstantIndexExpression& node) override;
|
||||
ExpressionPtr Clone(IdentifierExpression& node) override;
|
||||
ExpressionPtr Clone(IntrinsicExpression& node) override;
|
||||
ExpressionPtr Clone(SelectOptionExpression& node) override;
|
||||
ExpressionPtr Clone(SwizzleExpression& node) override;
|
||||
ExpressionPtr Clone(UnaryExpression& node) override;
|
||||
ExpressionPtr Clone(VariableExpression& node) override;
|
||||
|
||||
StatementPtr Clone(BranchStatement& node) override;
|
||||
StatementPtr Clone(ConditionalStatement& node) override;
|
||||
StatementPtr Clone(DeclareConstStatement& node) override;
|
||||
StatementPtr Clone(DeclareExternalStatement& node) override;
|
||||
StatementPtr Clone(DeclareFunctionStatement& node) override;
|
||||
StatementPtr Clone(DeclareOptionStatement& node) override;
|
||||
|
|
@ -84,6 +87,7 @@ namespace Nz::ShaderAst
|
|||
|
||||
template<typename T> const T& ComputeAttributeValue(AttributeValue<T>& attribute);
|
||||
ConstantValue ComputeConstantValue(Expression& expr);
|
||||
template<typename T> std::unique_ptr<T> Optimize(T& node);
|
||||
|
||||
std::size_t DeclareFunction(DeclareFunctionStatement& funcDecl);
|
||||
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ namespace Nz
|
|||
void Visit(ShaderAst::UnaryExpression& node) override;
|
||||
|
||||
void Visit(ShaderAst::BranchStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareConstStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareExternalStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareOptionStatement& node) override;
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ namespace Nz::ShaderBuilder
|
|||
inline std::unique_ptr<ShaderAst::BinaryExpression> operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const;
|
||||
};
|
||||
|
||||
template<bool Const>
|
||||
struct Branch
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::BranchStatement> operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath = nullptr) const;
|
||||
|
|
@ -70,6 +71,12 @@ namespace Nz::ShaderBuilder
|
|||
inline std::unique_ptr<ShaderAst::ConstantExpression> operator()(ShaderAst::ConstantValue value) const;
|
||||
};
|
||||
|
||||
struct DeclareConst
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::DeclareConstStatement> operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const;
|
||||
inline std::unique_ptr<ShaderAst::DeclareConstStatement> operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const;
|
||||
};
|
||||
|
||||
struct DeclareFunction
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::string name, ShaderAst::StatementPtr statement) const;
|
||||
|
|
@ -144,12 +151,14 @@ namespace Nz::ShaderBuilder
|
|||
constexpr Impl::AccessMember AccessMember;
|
||||
constexpr Impl::Assign Assign;
|
||||
constexpr Impl::Binary Binary;
|
||||
constexpr Impl::Branch Branch;
|
||||
constexpr Impl::Branch<false> Branch;
|
||||
constexpr Impl::CallFunction CallFunction;
|
||||
constexpr Impl::Cast Cast;
|
||||
constexpr Impl::ConditionalExpression ConditionalExpression;
|
||||
constexpr Impl::ConditionalStatement ConditionalStatement;
|
||||
constexpr Impl::Constant Constant;
|
||||
constexpr Impl::Branch<true> ConstBranch;
|
||||
constexpr Impl::DeclareConst DeclareConst;
|
||||
constexpr Impl::DeclareFunction DeclareFunction;
|
||||
constexpr Impl::DeclareOption DeclareOption;
|
||||
constexpr Impl::DeclareStruct DeclareStruct;
|
||||
|
|
|
|||
|
|
@ -57,7 +57,8 @@ namespace Nz::ShaderBuilder
|
|||
return binaryNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath) const
|
||||
template<bool Const>
|
||||
std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch<Const>::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath) const
|
||||
{
|
||||
auto branchNode = std::make_unique<ShaderAst::BranchStatement>();
|
||||
|
||||
|
|
@ -66,15 +67,18 @@ namespace Nz::ShaderBuilder
|
|||
condStatement.statement = std::move(truePath);
|
||||
|
||||
branchNode->elseStatement = std::move(falsePath);
|
||||
branchNode->isConst = Const;
|
||||
|
||||
return branchNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch::operator()(std::vector<ShaderAst::BranchStatement::ConditionalStatement> condStatements, ShaderAst::StatementPtr elseStatement) const
|
||||
template<bool Const>
|
||||
std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch<Const>::operator()(std::vector<ShaderAst::BranchStatement::ConditionalStatement> condStatements, ShaderAst::StatementPtr elseStatement) const
|
||||
{
|
||||
auto branchNode = std::make_unique<ShaderAst::BranchStatement>();
|
||||
branchNode->condStatements = std::move(condStatements);
|
||||
branchNode->elseStatement = std::move(elseStatement);
|
||||
branchNode->isConst = Const;
|
||||
|
||||
return branchNode;
|
||||
}
|
||||
|
|
@ -136,6 +140,25 @@ namespace Nz::ShaderBuilder
|
|||
return constantNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::DeclareConstStatement> Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const
|
||||
{
|
||||
auto declareConstNode = std::make_unique<ShaderAst::DeclareConstStatement>();
|
||||
declareConstNode->name = std::move(name);
|
||||
declareConstNode->expression = std::move(initialValue);
|
||||
|
||||
return declareConstNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::DeclareConstStatement> Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const
|
||||
{
|
||||
auto declareConstNode = std::make_unique<ShaderAst::DeclareConstStatement>();
|
||||
declareConstNode->name = std::move(name);
|
||||
declareConstNode->type = std::move(type);
|
||||
declareConstNode->expression = std::move(initialValue);
|
||||
|
||||
return declareConstNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::string name, ShaderAst::StatementPtr statement) const
|
||||
{
|
||||
auto declareFunctionNode = std::make_unique<ShaderAst::DeclareFunctionStatement>();
|
||||
|
|
|
|||
|
|
@ -81,9 +81,11 @@ namespace Nz::ShaderLang
|
|||
const Token& Peek(std::size_t advance = 0);
|
||||
|
||||
std::vector<ShaderAst::Attribute> ParseAttributes();
|
||||
void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionType& type, ShaderAst::ExpressionPtr& initialValue);
|
||||
|
||||
// Statements
|
||||
ShaderAst::StatementPtr ParseBranchStatement();
|
||||
ShaderAst::StatementPtr ParseConstStatement();
|
||||
ShaderAst::StatementPtr ParseDiscardStatement();
|
||||
ShaderAst::StatementPtr ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes = {});
|
||||
std::vector<ShaderAst::StatementPtr> ParseFunctionBody();
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ NAZARA_SHADERLANG_TOKEN(ClosingCurlyBracket)
|
|||
NAZARA_SHADERLANG_TOKEN(ClosingSquareBracket)
|
||||
NAZARA_SHADERLANG_TOKEN(Colon)
|
||||
NAZARA_SHADERLANG_TOKEN(Comma)
|
||||
NAZARA_SHADERLANG_TOKEN(Const)
|
||||
NAZARA_SHADERLANG_TOKEN(Discard)
|
||||
NAZARA_SHADERLANG_TOKEN(Divide)
|
||||
NAZARA_SHADERLANG_TOKEN(Dot)
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ namespace Nz
|
|||
void Visit(ShaderAst::CallFunctionExpression& node) override;
|
||||
void Visit(ShaderAst::CastExpression& node) override;
|
||||
void Visit(ShaderAst::ConstantExpression& node) override;
|
||||
void Visit(ShaderAst::DeclareConstStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareExternalStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
|
||||
void Visit(ShaderAst::DeclareOptionStatement& node) override;
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ namespace Nz
|
|||
std::array<UInt64, ShaderStageTypeCount> shaderConditions;
|
||||
shaderConditions.fill(0);
|
||||
shaderConditions[UnderlyingCast(ShaderStageType::Fragment)] = fragmentShader->GetOptionFlagByName("HAS_DIFFUSE_TEXTURE");
|
||||
shaderConditions[UnderlyingCast(ShaderStageType::Vertex)] = vertexShader->GetOptionFlagByName("HAS_DIFFUSE_TEXTURE");
|
||||
|
||||
s_conditionIndexes.hasDiffuseMap = settings.conditions.size();
|
||||
settings.conditions.push_back({
|
||||
|
|
@ -174,6 +175,7 @@ namespace Nz
|
|||
std::array<UInt64, ShaderStageTypeCount> shaderConditions;
|
||||
shaderConditions.fill(0);
|
||||
shaderConditions[UnderlyingCast(ShaderStageType::Fragment)] = fragmentShader->GetOptionFlagByName("HAS_ALPHA_TEXTURE");
|
||||
shaderConditions[UnderlyingCast(ShaderStageType::Vertex)] = vertexShader->GetOptionFlagByName("HAS_ALPHA_TEXTURE");
|
||||
|
||||
s_conditionIndexes.hasAlphaMap = settings.conditions.size();
|
||||
settings.conditions.push_back({
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ option HAS_DIFFUSE_TEXTURE: bool;
|
|||
option HAS_ALPHA_TEXTURE: bool;
|
||||
option ALPHA_TEST: bool;
|
||||
|
||||
const HasUV = HAS_DIFFUSE_TEXTURE || HAS_ALPHA_TEXTURE;
|
||||
|
||||
[layout(std140)]
|
||||
struct BasicSettings
|
||||
{
|
||||
|
|
@ -42,7 +44,7 @@ external
|
|||
// Fragment stage
|
||||
struct FragIn
|
||||
{
|
||||
[location(0)] uv: vec2<f32>
|
||||
[location(0), cond(HasUV)] uv: vec2<f32>
|
||||
}
|
||||
|
||||
struct FragOut
|
||||
|
|
@ -68,12 +70,12 @@ fn main(input: FragIn) -> FragOut
|
|||
struct VertIn
|
||||
{
|
||||
[location(0)] pos: vec3<f32>,
|
||||
[location(1)] uv: vec2<f32>
|
||||
[location(1), cond(HasUV)] uv: vec2<f32>
|
||||
}
|
||||
|
||||
struct VertOut
|
||||
{
|
||||
[location(0)] uv: vec2<f32>,
|
||||
[location(0), cond(HasUV)] uv: vec2<f32>,
|
||||
[builtin(position)] position: vec4<f32>
|
||||
}
|
||||
|
||||
|
|
@ -81,8 +83,10 @@ struct VertOut
|
|||
fn main(input: VertIn) -> VertOut
|
||||
{
|
||||
let output: VertOut;
|
||||
output.uv = input.uv;
|
||||
output.position = viewerData.projectionMatrix * viewerData.viewMatrix * instanceData.worldMatrix * vec4<f32>(input.pos, 1.0);
|
||||
|
||||
const if (HasUV)
|
||||
output.uv = input.uv;
|
||||
|
||||
return output;
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -40,6 +40,7 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
auto clone = std::make_unique<BranchStatement>();
|
||||
clone->condStatements.reserve(node.condStatements.size());
|
||||
clone->isConst = node.isConst;
|
||||
|
||||
for (auto& cond : node.condStatements)
|
||||
{
|
||||
|
|
@ -62,6 +63,17 @@ namespace Nz::ShaderAst
|
|||
return clone;
|
||||
}
|
||||
|
||||
StatementPtr AstCloner::Clone(DeclareConstStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareConstStatement>();
|
||||
clone->constIndex = node.constIndex;
|
||||
clone->name = node.name;
|
||||
clone->type = node.type;
|
||||
clone->expression = CloneExpression(node.expression);
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
StatementPtr AstCloner::Clone(DeclareExternalStatement& node)
|
||||
{
|
||||
auto clone = std::make_unique<DeclareExternalStatement>();
|
||||
|
|
|
|||
|
|
@ -844,7 +844,10 @@ namespace Nz::ShaderAst
|
|||
if (!m_options.constantQueryCallback)
|
||||
return AstCloner::Clone(node);
|
||||
|
||||
return ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
|
||||
auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
|
||||
constant->cachedExpressionType = GetExpressionType(constant->value);
|
||||
|
||||
return constant;
|
||||
}
|
||||
|
||||
ExpressionPtr AstOptimizer::Clone(UnaryExpression& node)
|
||||
|
|
|
|||
|
|
@ -122,6 +122,12 @@ namespace Nz::ShaderAst
|
|||
node.statement->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(DeclareConstStatement& node)
|
||||
{
|
||||
if (node.expression)
|
||||
node.expression->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(DeclareExternalStatement& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
|
|
|
|||
|
|
@ -208,6 +208,7 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
|
||||
Node(node.elseStatement);
|
||||
Value(node.isConst);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(ConditionalStatement& node)
|
||||
|
|
@ -232,6 +233,14 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(DeclareConstStatement& node)
|
||||
{
|
||||
OptVal(node.constIndex);
|
||||
Value(node.name);
|
||||
Type(node.type);
|
||||
Node(node.expression);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
|
||||
{
|
||||
Value(node.name);
|
||||
|
|
|
|||
|
|
@ -577,6 +577,18 @@ namespace Nz::ShaderAst
|
|||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(ConstantIndexExpression& node)
|
||||
{
|
||||
if (node.constantId >= m_context->constantValues.size())
|
||||
throw AstError{ "invalid constant index " + std::to_string(node.constantId) };
|
||||
|
||||
// Replace by constant value
|
||||
auto constant = ShaderBuilder::Constant(m_context->constantValues[node.constantId]);
|
||||
constant->cachedExpressionType = GetExpressionType(constant->value);
|
||||
|
||||
return constant;
|
||||
}
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(IdentifierExpression& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
|
@ -712,11 +724,46 @@ namespace Nz::ShaderAst
|
|||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr SanitizeVisitor::Clone(VariableExpression& node)
|
||||
{
|
||||
if (node.variableId >= m_context->variableTypes.size())
|
||||
throw AstError{ "invalid constant index " + std::to_string(node.variableId) };
|
||||
|
||||
node.cachedExpressionType = m_context->variableTypes[node.variableId];
|
||||
|
||||
return AstCloner::Clone(node);
|
||||
}
|
||||
|
||||
StatementPtr SanitizeVisitor::Clone(BranchStatement& node)
|
||||
{
|
||||
if (node.isConst)
|
||||
{
|
||||
// Evaluate every condition at compilation and select the right statement
|
||||
for (auto& cond : node.condStatements)
|
||||
{
|
||||
MandatoryExpr(cond.condition);
|
||||
|
||||
ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*cond.condition));
|
||||
if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean })
|
||||
throw AstError{ "expected a boolean value" };
|
||||
|
||||
if (std::get<bool>(conditionValue))
|
||||
return AstCloner::Clone(*cond.statement);
|
||||
}
|
||||
|
||||
// Every condition failed, fallback to else if any
|
||||
if (node.elseStatement)
|
||||
return AstCloner::Clone(*node.elseStatement);
|
||||
else
|
||||
return ShaderBuilder::NoOp();
|
||||
}
|
||||
|
||||
auto clone = std::make_unique<BranchStatement>();
|
||||
clone->condStatements.reserve(node.condStatements.size());
|
||||
|
||||
if (!m_context->currentFunction)
|
||||
throw AstError{ "non-const branching statements can only exist inside a function" };
|
||||
|
||||
for (auto& cond : node.condStatements)
|
||||
{
|
||||
PushScope();
|
||||
|
|
@ -758,6 +805,31 @@ namespace Nz::ShaderAst
|
|||
return ShaderBuilder::NoOp();
|
||||
}
|
||||
|
||||
StatementPtr SanitizeVisitor::Clone(DeclareConstStatement& node)
|
||||
{
|
||||
auto clone = static_unique_pointer_cast<DeclareConstStatement>(AstCloner::Clone(node));
|
||||
|
||||
if (!clone->expression)
|
||||
throw AstError{ "const variables must have an expression" };
|
||||
|
||||
clone->expression = Optimize(*clone->expression);
|
||||
if (clone->expression->GetType() != NodeType::ConstantExpression)
|
||||
throw AstError{ "const variable must have constant expressions " };
|
||||
|
||||
const ConstantValue& value = static_cast<ConstantExpression&>(*clone->expression).value;
|
||||
|
||||
ExpressionType expressionType = ResolveType(GetExpressionType(value));
|
||||
|
||||
if (!IsNoType(clone->type) && ResolveType(clone->type) != expressionType)
|
||||
throw AstError{ "constant expression doesn't match type" };
|
||||
|
||||
clone->type = expressionType;
|
||||
|
||||
clone->constIndex = RegisterConstant(clone->name, value);
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
|
@ -815,6 +887,9 @@ namespace Nz::ShaderAst
|
|||
|
||||
StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node)
|
||||
{
|
||||
if (m_context->currentFunction)
|
||||
throw AstError{ "a function cannot be defined inside another function" };
|
||||
|
||||
auto clone = std::make_unique<DeclareFunctionStatement>();
|
||||
clone->name = node.name;
|
||||
clone->parameters = node.parameters;
|
||||
|
|
@ -908,6 +983,9 @@ namespace Nz::ShaderAst
|
|||
|
||||
StatementPtr SanitizeVisitor::Clone(DeclareOptionStatement& node)
|
||||
{
|
||||
if (m_context->currentFunction)
|
||||
throw AstError{ "options must be declared outside of functions" };
|
||||
|
||||
auto clone = static_unique_pointer_cast<DeclareOptionStatement>(AstCloner::Clone(node));
|
||||
clone->optType = ResolveType(clone->optType);
|
||||
|
||||
|
|
@ -926,6 +1004,9 @@ namespace Nz::ShaderAst
|
|||
|
||||
StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node)
|
||||
{
|
||||
if (m_context->currentFunction)
|
||||
throw AstError{ "structs must be declared outside of functions" };
|
||||
|
||||
auto clone = static_unique_pointer_cast<DeclareStructStatement>(AstCloner::Clone(node));
|
||||
|
||||
std::unordered_set<std::string> declaredMembers;
|
||||
|
|
@ -961,6 +1042,9 @@ namespace Nz::ShaderAst
|
|||
|
||||
StatementPtr SanitizeVisitor::Clone(DeclareVariableStatement& node)
|
||||
{
|
||||
if (!m_context->currentFunction)
|
||||
throw AstError{ "global variables outside of external blocks are forbidden" };
|
||||
|
||||
auto clone = static_unique_pointer_cast<DeclareVariableStatement>(AstCloner::Clone(node));
|
||||
if (IsNoType(clone->varType))
|
||||
{
|
||||
|
|
@ -1092,6 +1176,17 @@ namespace Nz::ShaderAst
|
|||
}
|
||||
|
||||
ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr)
|
||||
{
|
||||
// Run optimizer on constant value to hopefully retrieve a single constant value
|
||||
ExpressionPtr optimizedExpr = Optimize(expr);
|
||||
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
|
||||
throw AstError{"expected a constant expression"};
|
||||
|
||||
return static_cast<ConstantExpression&>(*optimizedExpr).value;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::unique_ptr<T> SanitizeVisitor::Optimize(T& node)
|
||||
{
|
||||
AstOptimizer::Options optimizerOptions;
|
||||
optimizerOptions.constantQueryCallback = [this](std::size_t constantId)
|
||||
|
|
@ -1103,11 +1198,7 @@ namespace Nz::ShaderAst
|
|||
optimizerOptions.enabledOptions = m_context->options.enabledOptions;
|
||||
|
||||
// Run optimizer on constant value to hopefully retrieve a single constant value
|
||||
ExpressionPtr optimizedExpr = Optimize(expr, optimizerOptions);
|
||||
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
|
||||
throw AstError{"expected a constant expression"};
|
||||
|
||||
return static_cast<ConstantExpression&>(*optimizedExpr).value;
|
||||
return static_unique_pointer_cast<T>(ShaderAst::Optimize(node, optimizerOptions));
|
||||
}
|
||||
|
||||
std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement& funcDecl)
|
||||
|
|
|
|||
|
|
@ -737,6 +737,8 @@ namespace Nz
|
|||
|
||||
void GlslWriter::Visit(ShaderAst::BranchStatement& node)
|
||||
{
|
||||
assert(!node.isConst);
|
||||
|
||||
bool first = true;
|
||||
for (const auto& statement : node.condStatements)
|
||||
{
|
||||
|
|
@ -850,6 +852,11 @@ namespace Nz
|
|||
}, node.value);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/)
|
||||
{
|
||||
/* nothing to do */
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
{
|
||||
assert(node.varIndex);
|
||||
|
|
@ -1033,9 +1040,7 @@ namespace Nz
|
|||
assert(node.varIndex);
|
||||
RegisterVariable(*node.varIndex, node.varName);
|
||||
|
||||
Append(node.varType);
|
||||
Append(" ");
|
||||
Append(node.varName);
|
||||
Append(node.varType, " ", node.varName);
|
||||
if (node.initialExpression)
|
||||
{
|
||||
Append(" = ");
|
||||
|
|
|
|||
|
|
@ -818,6 +818,10 @@ namespace Nz
|
|||
Append("min");
|
||||
break;
|
||||
|
||||
case ShaderAst::IntrinsicType::Pow:
|
||||
Append("pow");
|
||||
break;
|
||||
|
||||
case ShaderAst::IntrinsicType::SampleTexture:
|
||||
assert(!node.parameters.empty());
|
||||
Visit(node.parameters.front(), true);
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ namespace Nz::ShaderLang
|
|||
ForceCLocale forceCLocale;
|
||||
|
||||
std::unordered_map<std::string, TokenType> reservedKeywords = {
|
||||
{ "const", TokenType::Const },
|
||||
{ "discard", TokenType::Discard },
|
||||
{ "else", TokenType::Else },
|
||||
{ "external", TokenType::External },
|
||||
|
|
|
|||
|
|
@ -118,6 +118,13 @@ namespace Nz::ShaderLang
|
|||
const Token& nextToken = Peek();
|
||||
switch (nextToken.type)
|
||||
{
|
||||
case TokenType::Const:
|
||||
if (!attributes.empty())
|
||||
throw UnexpectedToken{};
|
||||
|
||||
context.root->statements.push_back(ParseConstStatement());
|
||||
break;
|
||||
|
||||
case TokenType::EndOfStream:
|
||||
if (!attributes.empty())
|
||||
throw UnexpectedToken{};
|
||||
|
|
@ -400,6 +407,28 @@ namespace Nz::ShaderLang
|
|||
return attributes;
|
||||
}
|
||||
|
||||
void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionType& type, ShaderAst::ExpressionPtr& initialValue)
|
||||
{
|
||||
name = ParseIdentifierAsName();
|
||||
|
||||
if (Peek().type == TokenType::Colon)
|
||||
{
|
||||
Expect(Advance(), TokenType::Colon);
|
||||
|
||||
type = ParseType();
|
||||
}
|
||||
else
|
||||
type = ShaderAst::NoType{};
|
||||
|
||||
if (IsNoType(type) || Peek().type == TokenType::Assign)
|
||||
{
|
||||
Expect(Advance(), TokenType::Assign);
|
||||
initialValue = ParseExpression();
|
||||
}
|
||||
|
||||
Expect(Advance(), TokenType::Semicolon);
|
||||
}
|
||||
|
||||
ShaderAst::StatementPtr Parser::ParseBranchStatement()
|
||||
{
|
||||
std::unique_ptr<ShaderAst::BranchStatement> branch = std::make_unique<ShaderAst::BranchStatement>();
|
||||
|
|
@ -434,9 +463,41 @@ namespace Nz::ShaderLang
|
|||
return branch;
|
||||
}
|
||||
|
||||
ShaderAst::StatementPtr Parser::ParseConstStatement()
|
||||
{
|
||||
Expect(Advance(), TokenType::Const);
|
||||
|
||||
switch (Peek().type)
|
||||
{
|
||||
case TokenType::Identifier:
|
||||
{
|
||||
std::string constName;
|
||||
ShaderAst::ExpressionType constType;
|
||||
ShaderAst::ExpressionPtr initialValue;
|
||||
|
||||
ParseVariableDeclaration(constName, constType, initialValue);
|
||||
RegisterVariable(constName);
|
||||
|
||||
return ShaderBuilder::DeclareConst(std::move(constName), std::move(constType), std::move(initialValue));
|
||||
}
|
||||
|
||||
case TokenType::If:
|
||||
{
|
||||
auto branch = ParseBranchStatement();
|
||||
static_cast<ShaderAst::BranchStatement&>(*branch).isConst = true;
|
||||
|
||||
return branch;
|
||||
}
|
||||
|
||||
default:
|
||||
throw UnexpectedToken{};
|
||||
}
|
||||
}
|
||||
|
||||
ShaderAst::StatementPtr Parser::ParseDiscardStatement()
|
||||
{
|
||||
Expect(Advance(), TokenType::Discard);
|
||||
Expect(Advance(), TokenType::Semicolon);
|
||||
|
||||
return ShaderBuilder::Discard();
|
||||
}
|
||||
|
|
@ -728,6 +789,8 @@ namespace Nz::ShaderLang
|
|||
if (Peek().type != TokenType::Semicolon)
|
||||
expr = ParseExpression();
|
||||
|
||||
Expect(Advance(), TokenType::Semicolon);
|
||||
|
||||
return ShaderBuilder::Return(std::move(expr));
|
||||
}
|
||||
|
||||
|
|
@ -738,14 +801,16 @@ namespace Nz::ShaderLang
|
|||
ShaderAst::StatementPtr statement;
|
||||
switch (token.type)
|
||||
{
|
||||
case TokenType::Const:
|
||||
statement = ParseConstStatement();
|
||||
break;
|
||||
|
||||
case TokenType::Discard:
|
||||
statement = ParseDiscardStatement();
|
||||
Expect(Advance(), TokenType::Semicolon);
|
||||
break;
|
||||
|
||||
case TokenType::Let:
|
||||
statement = ParseVariableDeclaration();
|
||||
Expect(Advance(), TokenType::Semicolon);
|
||||
break;
|
||||
|
||||
case TokenType::Identifier:
|
||||
|
|
@ -759,7 +824,6 @@ namespace Nz::ShaderLang
|
|||
|
||||
case TokenType::Return:
|
||||
statement = ParseReturnStatement();
|
||||
Expect(Advance(), TokenType::Semicolon);
|
||||
break;
|
||||
|
||||
default:
|
||||
|
|
@ -809,23 +873,12 @@ namespace Nz::ShaderLang
|
|||
{
|
||||
Expect(Advance(), TokenType::Let);
|
||||
|
||||
std::string variableName = ParseIdentifierAsName();
|
||||
RegisterVariable(variableName);
|
||||
|
||||
ShaderAst::ExpressionType variableType = ShaderAst::NoType{};
|
||||
if (Peek().type == TokenType::Colon)
|
||||
{
|
||||
Expect(Advance(), TokenType::Colon);
|
||||
|
||||
variableType = ParseType();
|
||||
}
|
||||
|
||||
std::string variableName;
|
||||
ShaderAst::ExpressionType variableType;
|
||||
ShaderAst::ExpressionPtr expression;
|
||||
if (IsNoType(variableType) || Peek().type == TokenType::Assign)
|
||||
{
|
||||
Expect(Advance(), TokenType::Assign);
|
||||
expression = ParseExpression();
|
||||
}
|
||||
|
||||
ParseVariableDeclaration(variableName, variableType, expression);
|
||||
RegisterVariable(variableName);
|
||||
|
||||
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -584,6 +584,11 @@ namespace Nz
|
|||
}, node.value);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/)
|
||||
{
|
||||
/* nothing to do */
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
|
||||
{
|
||||
assert(node.varIndex);
|
||||
|
|
|
|||
|
|
@ -86,6 +86,23 @@ namespace Nz
|
|||
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::CallFunctionExpression& node) override
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
assert(m_funcIndex);
|
||||
auto& func = m_funcs[*m_funcIndex];
|
||||
|
||||
auto& funcCall = func.funcCalls.emplace_back();
|
||||
funcCall.firstVarIndex = func.variables.size();
|
||||
|
||||
for (const auto& parameter : node.parameters)
|
||||
{
|
||||
auto& var = func.variables.emplace_back();
|
||||
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::ConditionalExpression& node) override
|
||||
{
|
||||
throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?");
|
||||
|
|
@ -126,23 +143,6 @@ namespace Nz
|
|||
}
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::CallFunctionExpression& node) override
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
assert(m_funcIndex);
|
||||
auto& func = m_funcs[*m_funcIndex];
|
||||
|
||||
auto& funcCall = func.funcCalls.emplace_back();
|
||||
funcCall.firstVarIndex = func.variables.size();
|
||||
|
||||
for (const auto& parameter : node.parameters)
|
||||
{
|
||||
auto& var = func.variables.emplace_back();
|
||||
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
std::optional<ShaderStageType> entryPointType;
|
||||
|
|
|
|||
|
|
@ -68,8 +68,11 @@ void CodeOutputWidget::Refresh()
|
|||
{
|
||||
shaderAst = Nz::ShaderAst::Sanitize(*shaderAst);
|
||||
|
||||
Nz::ShaderAst::AstOptimizer::Options optimOptions;
|
||||
optimOptions.enabledOptions = enabledConditions;
|
||||
|
||||
Nz::ShaderAst::AstOptimizer optimiser;
|
||||
shaderAst = optimiser.Optimise(*shaderAst, enabledConditions);
|
||||
shaderAst = optimiser.Optimise(*shaderAst, optimOptions);
|
||||
}
|
||||
|
||||
Nz::ShaderWriter::States states;
|
||||
|
|
|
|||
|
|
@ -63,12 +63,12 @@ SCENARIO("Shader generation", "[Shader]")
|
|||
statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(outerStruct)));
|
||||
|
||||
auto external = std::make_unique<Nz::ShaderAst::DeclareExternalStatement>();
|
||||
external->externalVars.push_back({
|
||||
0,
|
||||
std::nullopt,
|
||||
"ubo",
|
||||
Nz::ShaderAst::UniformType{ Nz::ShaderAst::IdentifierType{ "outerStruct" } }
|
||||
});
|
||||
|
||||
auto& externalVar = external->externalVars.emplace_back();
|
||||
externalVar.bindingIndex = 0;
|
||||
externalVar.name = "ubo";
|
||||
externalVar.type = Nz::ShaderAst::UniformType{ Nz::ShaderAst::IdentifierType{ "outerStruct" } };
|
||||
|
||||
statements.push_back(std::move(external));
|
||||
|
||||
SECTION("Nested AccessMember")
|
||||
|
|
|
|||
Loading…
Reference in New Issue