Shader: Implement const if and const values

This commit is contained in:
Jérôme Leclercq 2021-07-07 21:38:23 +02:00
parent d679eccb43
commit 1f6937ab1b
28 changed files with 315 additions and 60 deletions

View File

@ -57,6 +57,7 @@ namespace Nz::ShaderAst
virtual StatementPtr Clone(BranchStatement& node);
virtual StatementPtr Clone(ConditionalStatement& node);
virtual StatementPtr Clone(DeclareConstStatement& node);
virtual StatementPtr Clone(DeclareExternalStatement& node);
virtual StatementPtr Clone(DeclareFunctionStatement& node);
virtual StatementPtr Clone(DeclareOptionStatement& node);

View File

@ -44,6 +44,7 @@ NAZARA_SHADERAST_EXPRESSION(VariableExpression)
NAZARA_SHADERAST_EXPRESSION(UnaryExpression)
NAZARA_SHADERAST_STATEMENT(BranchStatement)
NAZARA_SHADERAST_STATEMENT(ConditionalStatement)
NAZARA_SHADERAST_STATEMENT(DeclareConstStatement)
NAZARA_SHADERAST_STATEMENT(DeclareExternalStatement)
NAZARA_SHADERAST_STATEMENT(DeclareFunctionStatement)
NAZARA_SHADERAST_STATEMENT(DeclareOptionStatement)

View File

@ -39,6 +39,7 @@ namespace Nz::ShaderAst
void Visit(BranchStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(DeclareConstStatement& node) override;
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareOptionStatement& node) override;

View File

@ -42,6 +42,7 @@ namespace Nz::ShaderAst
void Serialize(BranchStatement& node);
void Serialize(ConditionalStatement& node);
void Serialize(DeclareConstStatement& node);
void Serialize(DeclareExternalStatement& node);
void Serialize(DeclareFunctionStatement& node);
void Serialize(DeclareOptionStatement& node);

View File

@ -242,6 +242,7 @@ namespace Nz::ShaderAst
std::vector<ConditionalStatement> condStatements;
StatementPtr elseStatement;
bool isConst = false;
};
struct NAZARA_SHADER_API ConditionalStatement : Statement
@ -253,6 +254,17 @@ namespace Nz::ShaderAst
StatementPtr statement;
};
struct NAZARA_SHADER_API DeclareConstStatement : Statement
{
NodeType GetType() const override;
void Visit(AstStatementVisitor& visitor) override;
std::optional<std::size_t> constIndex;
std::string name;
ExpressionPtr expression;
ExpressionType type;
};
struct NAZARA_SHADER_API DeclareExternalStatement : Statement
{
NodeType GetType() const override;

View File

@ -55,14 +55,17 @@ namespace Nz::ShaderAst
ExpressionPtr Clone(CastExpression& node) override;
ExpressionPtr Clone(ConditionalExpression& node) override;
ExpressionPtr Clone(ConstantExpression& node) override;
ExpressionPtr Clone(ConstantIndexExpression& node) override;
ExpressionPtr Clone(IdentifierExpression& node) override;
ExpressionPtr Clone(IntrinsicExpression& node) override;
ExpressionPtr Clone(SelectOptionExpression& node) override;
ExpressionPtr Clone(SwizzleExpression& node) override;
ExpressionPtr Clone(UnaryExpression& node) override;
ExpressionPtr Clone(VariableExpression& node) override;
StatementPtr Clone(BranchStatement& node) override;
StatementPtr Clone(ConditionalStatement& node) override;
StatementPtr Clone(DeclareConstStatement& node) override;
StatementPtr Clone(DeclareExternalStatement& node) override;
StatementPtr Clone(DeclareFunctionStatement& node) override;
StatementPtr Clone(DeclareOptionStatement& node) override;
@ -84,6 +87,7 @@ namespace Nz::ShaderAst
template<typename T> const T& ComputeAttributeValue(AttributeValue<T>& attribute);
ConstantValue ComputeConstantValue(Expression& expr);
template<typename T> std::unique_ptr<T> Optimize(T& node);
std::size_t DeclareFunction(DeclareFunctionStatement& funcDecl);

View File

@ -93,6 +93,7 @@ namespace Nz
void Visit(ShaderAst::UnaryExpression& node) override;
void Visit(ShaderAst::BranchStatement& node) override;
void Visit(ShaderAst::DeclareConstStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
void Visit(ShaderAst::DeclareOptionStatement& node) override;

View File

@ -38,6 +38,7 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::BinaryExpression> operator()(ShaderAst::BinaryType op, ShaderAst::ExpressionPtr left, ShaderAst::ExpressionPtr right) const;
};
template<bool Const>
struct Branch
{
inline std::unique_ptr<ShaderAst::BranchStatement> operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath = nullptr) const;
@ -70,6 +71,12 @@ namespace Nz::ShaderBuilder
inline std::unique_ptr<ShaderAst::ConstantExpression> operator()(ShaderAst::ConstantValue value) const;
};
struct DeclareConst
{
inline std::unique_ptr<ShaderAst::DeclareConstStatement> operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const;
inline std::unique_ptr<ShaderAst::DeclareConstStatement> operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue = nullptr) const;
};
struct DeclareFunction
{
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> operator()(std::string name, ShaderAst::StatementPtr statement) const;
@ -144,12 +151,14 @@ namespace Nz::ShaderBuilder
constexpr Impl::AccessMember AccessMember;
constexpr Impl::Assign Assign;
constexpr Impl::Binary Binary;
constexpr Impl::Branch Branch;
constexpr Impl::Branch<false> Branch;
constexpr Impl::CallFunction CallFunction;
constexpr Impl::Cast Cast;
constexpr Impl::ConditionalExpression ConditionalExpression;
constexpr Impl::ConditionalStatement ConditionalStatement;
constexpr Impl::Constant Constant;
constexpr Impl::Branch<true> ConstBranch;
constexpr Impl::DeclareConst DeclareConst;
constexpr Impl::DeclareFunction DeclareFunction;
constexpr Impl::DeclareOption DeclareOption;
constexpr Impl::DeclareStruct DeclareStruct;

View File

@ -57,7 +57,8 @@ namespace Nz::ShaderBuilder
return binaryNode;
}
inline std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath) const
template<bool Const>
std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch<Const>::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr truePath, ShaderAst::StatementPtr falsePath) const
{
auto branchNode = std::make_unique<ShaderAst::BranchStatement>();
@ -66,15 +67,18 @@ namespace Nz::ShaderBuilder
condStatement.statement = std::move(truePath);
branchNode->elseStatement = std::move(falsePath);
branchNode->isConst = Const;
return branchNode;
}
inline std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch::operator()(std::vector<ShaderAst::BranchStatement::ConditionalStatement> condStatements, ShaderAst::StatementPtr elseStatement) const
template<bool Const>
std::unique_ptr<ShaderAst::BranchStatement> Impl::Branch<Const>::operator()(std::vector<ShaderAst::BranchStatement::ConditionalStatement> condStatements, ShaderAst::StatementPtr elseStatement) const
{
auto branchNode = std::make_unique<ShaderAst::BranchStatement>();
branchNode->condStatements = std::move(condStatements);
branchNode->elseStatement = std::move(elseStatement);
branchNode->isConst = Const;
return branchNode;
}
@ -136,6 +140,25 @@ namespace Nz::ShaderBuilder
return constantNode;
}
inline std::unique_ptr<ShaderAst::DeclareConstStatement> Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const
{
auto declareConstNode = std::make_unique<ShaderAst::DeclareConstStatement>();
declareConstNode->name = std::move(name);
declareConstNode->expression = std::move(initialValue);
return declareConstNode;
}
inline std::unique_ptr<ShaderAst::DeclareConstStatement> Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionType type, ShaderAst::ExpressionPtr initialValue) const
{
auto declareConstNode = std::make_unique<ShaderAst::DeclareConstStatement>();
declareConstNode->name = std::move(name);
declareConstNode->type = std::move(type);
declareConstNode->expression = std::move(initialValue);
return declareConstNode;
}
inline std::unique_ptr<ShaderAst::DeclareFunctionStatement> Impl::DeclareFunction::operator()(std::string name, ShaderAst::StatementPtr statement) const
{
auto declareFunctionNode = std::make_unique<ShaderAst::DeclareFunctionStatement>();

View File

@ -81,9 +81,11 @@ namespace Nz::ShaderLang
const Token& Peek(std::size_t advance = 0);
std::vector<ShaderAst::Attribute> ParseAttributes();
void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionType& type, ShaderAst::ExpressionPtr& initialValue);
// Statements
ShaderAst::StatementPtr ParseBranchStatement();
ShaderAst::StatementPtr ParseConstStatement();
ShaderAst::StatementPtr ParseDiscardStatement();
ShaderAst::StatementPtr ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes = {});
std::vector<ShaderAst::StatementPtr> ParseFunctionBody();

View File

@ -18,6 +18,7 @@ NAZARA_SHADERLANG_TOKEN(ClosingCurlyBracket)
NAZARA_SHADERLANG_TOKEN(ClosingSquareBracket)
NAZARA_SHADERLANG_TOKEN(Colon)
NAZARA_SHADERLANG_TOKEN(Comma)
NAZARA_SHADERLANG_TOKEN(Const)
NAZARA_SHADERLANG_TOKEN(Discard)
NAZARA_SHADERLANG_TOKEN(Divide)
NAZARA_SHADERLANG_TOKEN(Dot)

View File

@ -48,6 +48,7 @@ namespace Nz
void Visit(ShaderAst::CallFunctionExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConstantExpression& node) override;
void Visit(ShaderAst::DeclareConstStatement& node) override;
void Visit(ShaderAst::DeclareExternalStatement& node) override;
void Visit(ShaderAst::DeclareFunctionStatement& node) override;
void Visit(ShaderAst::DeclareOptionStatement& node) override;

View File

@ -161,6 +161,7 @@ namespace Nz
std::array<UInt64, ShaderStageTypeCount> shaderConditions;
shaderConditions.fill(0);
shaderConditions[UnderlyingCast(ShaderStageType::Fragment)] = fragmentShader->GetOptionFlagByName("HAS_DIFFUSE_TEXTURE");
shaderConditions[UnderlyingCast(ShaderStageType::Vertex)] = vertexShader->GetOptionFlagByName("HAS_DIFFUSE_TEXTURE");
s_conditionIndexes.hasDiffuseMap = settings.conditions.size();
settings.conditions.push_back({
@ -174,6 +175,7 @@ namespace Nz
std::array<UInt64, ShaderStageTypeCount> shaderConditions;
shaderConditions.fill(0);
shaderConditions[UnderlyingCast(ShaderStageType::Fragment)] = fragmentShader->GetOptionFlagByName("HAS_ALPHA_TEXTURE");
shaderConditions[UnderlyingCast(ShaderStageType::Vertex)] = vertexShader->GetOptionFlagByName("HAS_ALPHA_TEXTURE");
s_conditionIndexes.hasAlphaMap = settings.conditions.size();
settings.conditions.push_back({

View File

@ -2,6 +2,8 @@ option HAS_DIFFUSE_TEXTURE: bool;
option HAS_ALPHA_TEXTURE: bool;
option ALPHA_TEST: bool;
const HasUV = HAS_DIFFUSE_TEXTURE || HAS_ALPHA_TEXTURE;
[layout(std140)]
struct BasicSettings
{
@ -42,7 +44,7 @@ external
// Fragment stage
struct FragIn
{
[location(0)] uv: vec2<f32>
[location(0), cond(HasUV)] uv: vec2<f32>
}
struct FragOut
@ -68,12 +70,12 @@ fn main(input: FragIn) -> FragOut
struct VertIn
{
[location(0)] pos: vec3<f32>,
[location(1)] uv: vec2<f32>
[location(1), cond(HasUV)] uv: vec2<f32>
}
struct VertOut
{
[location(0)] uv: vec2<f32>,
[location(0), cond(HasUV)] uv: vec2<f32>,
[builtin(position)] position: vec4<f32>
}
@ -81,8 +83,10 @@ struct VertOut
fn main(input: VertIn) -> VertOut
{
let output: VertOut;
output.uv = input.uv;
output.position = viewerData.projectionMatrix * viewerData.viewMatrix * instanceData.worldMatrix * vec4<f32>(input.pos, 1.0);
const if (HasUV)
output.uv = input.uv;
return output;
}

File diff suppressed because one or more lines are too long

View File

@ -40,6 +40,7 @@ namespace Nz::ShaderAst
{
auto clone = std::make_unique<BranchStatement>();
clone->condStatements.reserve(node.condStatements.size());
clone->isConst = node.isConst;
for (auto& cond : node.condStatements)
{
@ -62,6 +63,17 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr AstCloner::Clone(DeclareConstStatement& node)
{
auto clone = std::make_unique<DeclareConstStatement>();
clone->constIndex = node.constIndex;
clone->name = node.name;
clone->type = node.type;
clone->expression = CloneExpression(node.expression);
return clone;
}
StatementPtr AstCloner::Clone(DeclareExternalStatement& node)
{
auto clone = std::make_unique<DeclareExternalStatement>();

View File

@ -844,7 +844,10 @@ namespace Nz::ShaderAst
if (!m_options.constantQueryCallback)
return AstCloner::Clone(node);
return ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
constant->cachedExpressionType = GetExpressionType(constant->value);
return constant;
}
ExpressionPtr AstOptimizer::Clone(UnaryExpression& node)

View File

@ -122,6 +122,12 @@ namespace Nz::ShaderAst
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareConstStatement& node)
{
if (node.expression)
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareExternalStatement& /*node*/)
{
/* Nothing to do */

View File

@ -208,6 +208,7 @@ namespace Nz::ShaderAst
}
Node(node.elseStatement);
Value(node.isConst);
}
void AstSerializerBase::Serialize(ConditionalStatement& node)
@ -232,6 +233,14 @@ namespace Nz::ShaderAst
}
}
void AstSerializerBase::Serialize(DeclareConstStatement& node)
{
OptVal(node.constIndex);
Value(node.name);
Type(node.type);
Node(node.expression);
}
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
{
Value(node.name);

View File

@ -577,6 +577,18 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(ConstantIndexExpression& node)
{
if (node.constantId >= m_context->constantValues.size())
throw AstError{ "invalid constant index " + std::to_string(node.constantId) };
// Replace by constant value
auto constant = ShaderBuilder::Constant(m_context->constantValues[node.constantId]);
constant->cachedExpressionType = GetExpressionType(constant->value);
return constant;
}
ExpressionPtr SanitizeVisitor::Clone(IdentifierExpression& node)
{
assert(m_context);
@ -712,11 +724,46 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(VariableExpression& node)
{
if (node.variableId >= m_context->variableTypes.size())
throw AstError{ "invalid constant index " + std::to_string(node.variableId) };
node.cachedExpressionType = m_context->variableTypes[node.variableId];
return AstCloner::Clone(node);
}
StatementPtr SanitizeVisitor::Clone(BranchStatement& node)
{
if (node.isConst)
{
// Evaluate every condition at compilation and select the right statement
for (auto& cond : node.condStatements)
{
MandatoryExpr(cond.condition);
ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*cond.condition));
if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean })
throw AstError{ "expected a boolean value" };
if (std::get<bool>(conditionValue))
return AstCloner::Clone(*cond.statement);
}
// Every condition failed, fallback to else if any
if (node.elseStatement)
return AstCloner::Clone(*node.elseStatement);
else
return ShaderBuilder::NoOp();
}
auto clone = std::make_unique<BranchStatement>();
clone->condStatements.reserve(node.condStatements.size());
if (!m_context->currentFunction)
throw AstError{ "non-const branching statements can only exist inside a function" };
for (auto& cond : node.condStatements)
{
PushScope();
@ -758,6 +805,31 @@ namespace Nz::ShaderAst
return ShaderBuilder::NoOp();
}
StatementPtr SanitizeVisitor::Clone(DeclareConstStatement& node)
{
auto clone = static_unique_pointer_cast<DeclareConstStatement>(AstCloner::Clone(node));
if (!clone->expression)
throw AstError{ "const variables must have an expression" };
clone->expression = Optimize(*clone->expression);
if (clone->expression->GetType() != NodeType::ConstantExpression)
throw AstError{ "const variable must have constant expressions " };
const ConstantValue& value = static_cast<ConstantExpression&>(*clone->expression).value;
ExpressionType expressionType = ResolveType(GetExpressionType(value));
if (!IsNoType(clone->type) && ResolveType(clone->type) != expressionType)
throw AstError{ "constant expression doesn't match type" };
clone->type = expressionType;
clone->constIndex = RegisterConstant(clone->name, value);
return clone;
}
StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node)
{
assert(m_context);
@ -815,6 +887,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node)
{
if (m_context->currentFunction)
throw AstError{ "a function cannot be defined inside another function" };
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->name = node.name;
clone->parameters = node.parameters;
@ -908,6 +983,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareOptionStatement& node)
{
if (m_context->currentFunction)
throw AstError{ "options must be declared outside of functions" };
auto clone = static_unique_pointer_cast<DeclareOptionStatement>(AstCloner::Clone(node));
clone->optType = ResolveType(clone->optType);
@ -926,6 +1004,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node)
{
if (m_context->currentFunction)
throw AstError{ "structs must be declared outside of functions" };
auto clone = static_unique_pointer_cast<DeclareStructStatement>(AstCloner::Clone(node));
std::unordered_set<std::string> declaredMembers;
@ -961,6 +1042,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareVariableStatement& node)
{
if (!m_context->currentFunction)
throw AstError{ "global variables outside of external blocks are forbidden" };
auto clone = static_unique_pointer_cast<DeclareVariableStatement>(AstCloner::Clone(node));
if (IsNoType(clone->varType))
{
@ -1092,6 +1176,17 @@ namespace Nz::ShaderAst
}
ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr)
{
// Run optimizer on constant value to hopefully retrieve a single constant value
ExpressionPtr optimizedExpr = Optimize(expr);
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
throw AstError{"expected a constant expression"};
return static_cast<ConstantExpression&>(*optimizedExpr).value;
}
template<typename T>
std::unique_ptr<T> SanitizeVisitor::Optimize(T& node)
{
AstOptimizer::Options optimizerOptions;
optimizerOptions.constantQueryCallback = [this](std::size_t constantId)
@ -1103,11 +1198,7 @@ namespace Nz::ShaderAst
optimizerOptions.enabledOptions = m_context->options.enabledOptions;
// Run optimizer on constant value to hopefully retrieve a single constant value
ExpressionPtr optimizedExpr = Optimize(expr, optimizerOptions);
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
throw AstError{"expected a constant expression"};
return static_cast<ConstantExpression&>(*optimizedExpr).value;
return static_unique_pointer_cast<T>(ShaderAst::Optimize(node, optimizerOptions));
}
std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement& funcDecl)

View File

@ -737,6 +737,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::BranchStatement& node)
{
assert(!node.isConst);
bool first = true;
for (const auto& statement : node.condStatements)
{
@ -850,6 +852,11 @@ namespace Nz
}, node.value);
}
void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{
/* nothing to do */
}
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{
assert(node.varIndex);
@ -1033,9 +1040,7 @@ namespace Nz
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
Append(node.varType);
Append(" ");
Append(node.varName);
Append(node.varType, " ", node.varName);
if (node.initialExpression)
{
Append(" = ");

View File

@ -818,6 +818,10 @@ namespace Nz
Append("min");
break;
case ShaderAst::IntrinsicType::Pow:
Append("pow");
break;
case ShaderAst::IntrinsicType::SampleTexture:
assert(!node.parameters.empty());
Visit(node.parameters.front(), true);

View File

@ -40,6 +40,7 @@ namespace Nz::ShaderLang
ForceCLocale forceCLocale;
std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "const", TokenType::Const },
{ "discard", TokenType::Discard },
{ "else", TokenType::Else },
{ "external", TokenType::External },

View File

@ -118,6 +118,13 @@ namespace Nz::ShaderLang
const Token& nextToken = Peek();
switch (nextToken.type)
{
case TokenType::Const:
if (!attributes.empty())
throw UnexpectedToken{};
context.root->statements.push_back(ParseConstStatement());
break;
case TokenType::EndOfStream:
if (!attributes.empty())
throw UnexpectedToken{};
@ -400,6 +407,28 @@ namespace Nz::ShaderLang
return attributes;
}
void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionType& type, ShaderAst::ExpressionPtr& initialValue)
{
name = ParseIdentifierAsName();
if (Peek().type == TokenType::Colon)
{
Expect(Advance(), TokenType::Colon);
type = ParseType();
}
else
type = ShaderAst::NoType{};
if (IsNoType(type) || Peek().type == TokenType::Assign)
{
Expect(Advance(), TokenType::Assign);
initialValue = ParseExpression();
}
Expect(Advance(), TokenType::Semicolon);
}
ShaderAst::StatementPtr Parser::ParseBranchStatement()
{
std::unique_ptr<ShaderAst::BranchStatement> branch = std::make_unique<ShaderAst::BranchStatement>();
@ -434,9 +463,41 @@ namespace Nz::ShaderLang
return branch;
}
ShaderAst::StatementPtr Parser::ParseConstStatement()
{
Expect(Advance(), TokenType::Const);
switch (Peek().type)
{
case TokenType::Identifier:
{
std::string constName;
ShaderAst::ExpressionType constType;
ShaderAst::ExpressionPtr initialValue;
ParseVariableDeclaration(constName, constType, initialValue);
RegisterVariable(constName);
return ShaderBuilder::DeclareConst(std::move(constName), std::move(constType), std::move(initialValue));
}
case TokenType::If:
{
auto branch = ParseBranchStatement();
static_cast<ShaderAst::BranchStatement&>(*branch).isConst = true;
return branch;
}
default:
throw UnexpectedToken{};
}
}
ShaderAst::StatementPtr Parser::ParseDiscardStatement()
{
Expect(Advance(), TokenType::Discard);
Expect(Advance(), TokenType::Semicolon);
return ShaderBuilder::Discard();
}
@ -728,6 +789,8 @@ namespace Nz::ShaderLang
if (Peek().type != TokenType::Semicolon)
expr = ParseExpression();
Expect(Advance(), TokenType::Semicolon);
return ShaderBuilder::Return(std::move(expr));
}
@ -738,14 +801,16 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr statement;
switch (token.type)
{
case TokenType::Const:
statement = ParseConstStatement();
break;
case TokenType::Discard:
statement = ParseDiscardStatement();
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::Let:
statement = ParseVariableDeclaration();
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::Identifier:
@ -759,7 +824,6 @@ namespace Nz::ShaderLang
case TokenType::Return:
statement = ParseReturnStatement();
Expect(Advance(), TokenType::Semicolon);
break;
default:
@ -809,23 +873,12 @@ namespace Nz::ShaderLang
{
Expect(Advance(), TokenType::Let);
std::string variableName = ParseIdentifierAsName();
RegisterVariable(variableName);
ShaderAst::ExpressionType variableType = ShaderAst::NoType{};
if (Peek().type == TokenType::Colon)
{
Expect(Advance(), TokenType::Colon);
variableType = ParseType();
}
std::string variableName;
ShaderAst::ExpressionType variableType;
ShaderAst::ExpressionPtr expression;
if (IsNoType(variableType) || Peek().type == TokenType::Assign)
{
Expect(Advance(), TokenType::Assign);
expression = ParseExpression();
}
ParseVariableDeclaration(variableName, variableType, expression);
RegisterVariable(variableName);
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}

View File

@ -584,6 +584,11 @@ namespace Nz
}, node.value);
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{
/* nothing to do */
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
{
assert(node.varIndex);

View File

@ -86,6 +86,23 @@ namespace Nz
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::CallFunctionExpression& node) override
{
AstRecursiveVisitor::Visit(node);
assert(m_funcIndex);
auto& func = m_funcs[*m_funcIndex];
auto& funcCall = func.funcCalls.emplace_back();
funcCall.firstVarIndex = func.variables.size();
for (const auto& parameter : node.parameters)
{
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
}
}
void Visit(ShaderAst::ConditionalExpression& node) override
{
throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?");
@ -126,23 +143,6 @@ namespace Nz
}
}
void Visit(ShaderAst::CallFunctionExpression& node) override
{
AstRecursiveVisitor::Visit(node);
assert(m_funcIndex);
auto& func = m_funcs[*m_funcIndex];
auto& funcCall = func.funcCalls.emplace_back();
funcCall.firstVarIndex = func.variables.size();
for (const auto& parameter : node.parameters)
{
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
}
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
std::optional<ShaderStageType> entryPointType;

View File

@ -68,8 +68,11 @@ void CodeOutputWidget::Refresh()
{
shaderAst = Nz::ShaderAst::Sanitize(*shaderAst);
Nz::ShaderAst::AstOptimizer::Options optimOptions;
optimOptions.enabledOptions = enabledConditions;
Nz::ShaderAst::AstOptimizer optimiser;
shaderAst = optimiser.Optimise(*shaderAst, enabledConditions);
shaderAst = optimiser.Optimise(*shaderAst, optimOptions);
}
Nz::ShaderWriter::States states;

View File

@ -63,12 +63,12 @@ SCENARIO("Shader generation", "[Shader]")
statements.push_back(Nz::ShaderBuilder::DeclareStruct(std::move(outerStruct)));
auto external = std::make_unique<Nz::ShaderAst::DeclareExternalStatement>();
external->externalVars.push_back({
0,
std::nullopt,
"ubo",
Nz::ShaderAst::UniformType{ Nz::ShaderAst::IdentifierType{ "outerStruct" } }
});
auto& externalVar = external->externalVars.emplace_back();
externalVar.bindingIndex = 0;
externalVar.name = "ubo";
externalVar.type = Nz::ShaderAst::UniformType{ Nz::ShaderAst::IdentifierType{ "outerStruct" } };
statements.push_back(std::move(external));
SECTION("Nested AccessMember")