Shader: Implement const if and const values

This commit is contained in:
Jérôme Leclercq
2021-07-07 21:38:23 +02:00
parent d679eccb43
commit 1f6937ab1b
28 changed files with 315 additions and 60 deletions

View File

@@ -161,6 +161,7 @@ namespace Nz
std::array<UInt64, ShaderStageTypeCount> shaderConditions;
shaderConditions.fill(0);
shaderConditions[UnderlyingCast(ShaderStageType::Fragment)] = fragmentShader->GetOptionFlagByName("HAS_DIFFUSE_TEXTURE");
shaderConditions[UnderlyingCast(ShaderStageType::Vertex)] = vertexShader->GetOptionFlagByName("HAS_DIFFUSE_TEXTURE");
s_conditionIndexes.hasDiffuseMap = settings.conditions.size();
settings.conditions.push_back({
@@ -174,6 +175,7 @@ namespace Nz
std::array<UInt64, ShaderStageTypeCount> shaderConditions;
shaderConditions.fill(0);
shaderConditions[UnderlyingCast(ShaderStageType::Fragment)] = fragmentShader->GetOptionFlagByName("HAS_ALPHA_TEXTURE");
shaderConditions[UnderlyingCast(ShaderStageType::Vertex)] = vertexShader->GetOptionFlagByName("HAS_ALPHA_TEXTURE");
s_conditionIndexes.hasAlphaMap = settings.conditions.size();
settings.conditions.push_back({

View File

@@ -2,6 +2,8 @@ option HAS_DIFFUSE_TEXTURE: bool;
option HAS_ALPHA_TEXTURE: bool;
option ALPHA_TEST: bool;
const HasUV = HAS_DIFFUSE_TEXTURE || HAS_ALPHA_TEXTURE;
[layout(std140)]
struct BasicSettings
{
@@ -42,7 +44,7 @@ external
// Fragment stage
struct FragIn
{
[location(0)] uv: vec2<f32>
[location(0), cond(HasUV)] uv: vec2<f32>
}
struct FragOut
@@ -68,12 +70,12 @@ fn main(input: FragIn) -> FragOut
struct VertIn
{
[location(0)] pos: vec3<f32>,
[location(1)] uv: vec2<f32>
[location(1), cond(HasUV)] uv: vec2<f32>
}
struct VertOut
{
[location(0)] uv: vec2<f32>,
[location(0), cond(HasUV)] uv: vec2<f32>,
[builtin(position)] position: vec4<f32>
}
@@ -81,8 +83,10 @@ struct VertOut
fn main(input: VertIn) -> VertOut
{
let output: VertOut;
output.uv = input.uv;
output.position = viewerData.projectionMatrix * viewerData.viewMatrix * instanceData.worldMatrix * vec4<f32>(input.pos, 1.0);
const if (HasUV)
output.uv = input.uv;
return output;
}

File diff suppressed because one or more lines are too long

View File

@@ -40,6 +40,7 @@ namespace Nz::ShaderAst
{
auto clone = std::make_unique<BranchStatement>();
clone->condStatements.reserve(node.condStatements.size());
clone->isConst = node.isConst;
for (auto& cond : node.condStatements)
{
@@ -62,6 +63,17 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr AstCloner::Clone(DeclareConstStatement& node)
{
auto clone = std::make_unique<DeclareConstStatement>();
clone->constIndex = node.constIndex;
clone->name = node.name;
clone->type = node.type;
clone->expression = CloneExpression(node.expression);
return clone;
}
StatementPtr AstCloner::Clone(DeclareExternalStatement& node)
{
auto clone = std::make_unique<DeclareExternalStatement>();

View File

@@ -844,7 +844,10 @@ namespace Nz::ShaderAst
if (!m_options.constantQueryCallback)
return AstCloner::Clone(node);
return ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
constant->cachedExpressionType = GetExpressionType(constant->value);
return constant;
}
ExpressionPtr AstOptimizer::Clone(UnaryExpression& node)

View File

@@ -122,6 +122,12 @@ namespace Nz::ShaderAst
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareConstStatement& node)
{
if (node.expression)
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareExternalStatement& /*node*/)
{
/* Nothing to do */

View File

@@ -208,6 +208,7 @@ namespace Nz::ShaderAst
}
Node(node.elseStatement);
Value(node.isConst);
}
void AstSerializerBase::Serialize(ConditionalStatement& node)
@@ -232,6 +233,14 @@ namespace Nz::ShaderAst
}
}
void AstSerializerBase::Serialize(DeclareConstStatement& node)
{
OptVal(node.constIndex);
Value(node.name);
Type(node.type);
Node(node.expression);
}
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
{
Value(node.name);

View File

@@ -577,6 +577,18 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(ConstantIndexExpression& node)
{
if (node.constantId >= m_context->constantValues.size())
throw AstError{ "invalid constant index " + std::to_string(node.constantId) };
// Replace by constant value
auto constant = ShaderBuilder::Constant(m_context->constantValues[node.constantId]);
constant->cachedExpressionType = GetExpressionType(constant->value);
return constant;
}
ExpressionPtr SanitizeVisitor::Clone(IdentifierExpression& node)
{
assert(m_context);
@@ -712,11 +724,46 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr SanitizeVisitor::Clone(VariableExpression& node)
{
if (node.variableId >= m_context->variableTypes.size())
throw AstError{ "invalid constant index " + std::to_string(node.variableId) };
node.cachedExpressionType = m_context->variableTypes[node.variableId];
return AstCloner::Clone(node);
}
StatementPtr SanitizeVisitor::Clone(BranchStatement& node)
{
if (node.isConst)
{
// Evaluate every condition at compilation and select the right statement
for (auto& cond : node.condStatements)
{
MandatoryExpr(cond.condition);
ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*cond.condition));
if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean })
throw AstError{ "expected a boolean value" };
if (std::get<bool>(conditionValue))
return AstCloner::Clone(*cond.statement);
}
// Every condition failed, fallback to else if any
if (node.elseStatement)
return AstCloner::Clone(*node.elseStatement);
else
return ShaderBuilder::NoOp();
}
auto clone = std::make_unique<BranchStatement>();
clone->condStatements.reserve(node.condStatements.size());
if (!m_context->currentFunction)
throw AstError{ "non-const branching statements can only exist inside a function" };
for (auto& cond : node.condStatements)
{
PushScope();
@@ -758,6 +805,31 @@ namespace Nz::ShaderAst
return ShaderBuilder::NoOp();
}
StatementPtr SanitizeVisitor::Clone(DeclareConstStatement& node)
{
auto clone = static_unique_pointer_cast<DeclareConstStatement>(AstCloner::Clone(node));
if (!clone->expression)
throw AstError{ "const variables must have an expression" };
clone->expression = Optimize(*clone->expression);
if (clone->expression->GetType() != NodeType::ConstantExpression)
throw AstError{ "const variable must have constant expressions " };
const ConstantValue& value = static_cast<ConstantExpression&>(*clone->expression).value;
ExpressionType expressionType = ResolveType(GetExpressionType(value));
if (!IsNoType(clone->type) && ResolveType(clone->type) != expressionType)
throw AstError{ "constant expression doesn't match type" };
clone->type = expressionType;
clone->constIndex = RegisterConstant(clone->name, value);
return clone;
}
StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node)
{
assert(m_context);
@@ -815,6 +887,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node)
{
if (m_context->currentFunction)
throw AstError{ "a function cannot be defined inside another function" };
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->name = node.name;
clone->parameters = node.parameters;
@@ -908,6 +983,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareOptionStatement& node)
{
if (m_context->currentFunction)
throw AstError{ "options must be declared outside of functions" };
auto clone = static_unique_pointer_cast<DeclareOptionStatement>(AstCloner::Clone(node));
clone->optType = ResolveType(clone->optType);
@@ -926,6 +1004,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node)
{
if (m_context->currentFunction)
throw AstError{ "structs must be declared outside of functions" };
auto clone = static_unique_pointer_cast<DeclareStructStatement>(AstCloner::Clone(node));
std::unordered_set<std::string> declaredMembers;
@@ -961,6 +1042,9 @@ namespace Nz::ShaderAst
StatementPtr SanitizeVisitor::Clone(DeclareVariableStatement& node)
{
if (!m_context->currentFunction)
throw AstError{ "global variables outside of external blocks are forbidden" };
auto clone = static_unique_pointer_cast<DeclareVariableStatement>(AstCloner::Clone(node));
if (IsNoType(clone->varType))
{
@@ -1092,6 +1176,17 @@ namespace Nz::ShaderAst
}
ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr)
{
// Run optimizer on constant value to hopefully retrieve a single constant value
ExpressionPtr optimizedExpr = Optimize(expr);
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
throw AstError{"expected a constant expression"};
return static_cast<ConstantExpression&>(*optimizedExpr).value;
}
template<typename T>
std::unique_ptr<T> SanitizeVisitor::Optimize(T& node)
{
AstOptimizer::Options optimizerOptions;
optimizerOptions.constantQueryCallback = [this](std::size_t constantId)
@@ -1103,11 +1198,7 @@ namespace Nz::ShaderAst
optimizerOptions.enabledOptions = m_context->options.enabledOptions;
// Run optimizer on constant value to hopefully retrieve a single constant value
ExpressionPtr optimizedExpr = Optimize(expr, optimizerOptions);
if (optimizedExpr->GetType() != NodeType::ConstantExpression)
throw AstError{"expected a constant expression"};
return static_cast<ConstantExpression&>(*optimizedExpr).value;
return static_unique_pointer_cast<T>(ShaderAst::Optimize(node, optimizerOptions));
}
std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement& funcDecl)

View File

@@ -737,6 +737,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::BranchStatement& node)
{
assert(!node.isConst);
bool first = true;
for (const auto& statement : node.condStatements)
{
@@ -850,6 +852,11 @@ namespace Nz
}, node.value);
}
void GlslWriter::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{
/* nothing to do */
}
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{
assert(node.varIndex);
@@ -1033,9 +1040,7 @@ namespace Nz
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
Append(node.varType);
Append(" ");
Append(node.varName);
Append(node.varType, " ", node.varName);
if (node.initialExpression)
{
Append(" = ");

View File

@@ -818,6 +818,10 @@ namespace Nz
Append("min");
break;
case ShaderAst::IntrinsicType::Pow:
Append("pow");
break;
case ShaderAst::IntrinsicType::SampleTexture:
assert(!node.parameters.empty());
Visit(node.parameters.front(), true);

View File

@@ -40,6 +40,7 @@ namespace Nz::ShaderLang
ForceCLocale forceCLocale;
std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "const", TokenType::Const },
{ "discard", TokenType::Discard },
{ "else", TokenType::Else },
{ "external", TokenType::External },

View File

@@ -118,6 +118,13 @@ namespace Nz::ShaderLang
const Token& nextToken = Peek();
switch (nextToken.type)
{
case TokenType::Const:
if (!attributes.empty())
throw UnexpectedToken{};
context.root->statements.push_back(ParseConstStatement());
break;
case TokenType::EndOfStream:
if (!attributes.empty())
throw UnexpectedToken{};
@@ -400,6 +407,28 @@ namespace Nz::ShaderLang
return attributes;
}
void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionType& type, ShaderAst::ExpressionPtr& initialValue)
{
name = ParseIdentifierAsName();
if (Peek().type == TokenType::Colon)
{
Expect(Advance(), TokenType::Colon);
type = ParseType();
}
else
type = ShaderAst::NoType{};
if (IsNoType(type) || Peek().type == TokenType::Assign)
{
Expect(Advance(), TokenType::Assign);
initialValue = ParseExpression();
}
Expect(Advance(), TokenType::Semicolon);
}
ShaderAst::StatementPtr Parser::ParseBranchStatement()
{
std::unique_ptr<ShaderAst::BranchStatement> branch = std::make_unique<ShaderAst::BranchStatement>();
@@ -434,9 +463,41 @@ namespace Nz::ShaderLang
return branch;
}
ShaderAst::StatementPtr Parser::ParseConstStatement()
{
Expect(Advance(), TokenType::Const);
switch (Peek().type)
{
case TokenType::Identifier:
{
std::string constName;
ShaderAst::ExpressionType constType;
ShaderAst::ExpressionPtr initialValue;
ParseVariableDeclaration(constName, constType, initialValue);
RegisterVariable(constName);
return ShaderBuilder::DeclareConst(std::move(constName), std::move(constType), std::move(initialValue));
}
case TokenType::If:
{
auto branch = ParseBranchStatement();
static_cast<ShaderAst::BranchStatement&>(*branch).isConst = true;
return branch;
}
default:
throw UnexpectedToken{};
}
}
ShaderAst::StatementPtr Parser::ParseDiscardStatement()
{
Expect(Advance(), TokenType::Discard);
Expect(Advance(), TokenType::Semicolon);
return ShaderBuilder::Discard();
}
@@ -728,6 +789,8 @@ namespace Nz::ShaderLang
if (Peek().type != TokenType::Semicolon)
expr = ParseExpression();
Expect(Advance(), TokenType::Semicolon);
return ShaderBuilder::Return(std::move(expr));
}
@@ -738,14 +801,16 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr statement;
switch (token.type)
{
case TokenType::Const:
statement = ParseConstStatement();
break;
case TokenType::Discard:
statement = ParseDiscardStatement();
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::Let:
statement = ParseVariableDeclaration();
Expect(Advance(), TokenType::Semicolon);
break;
case TokenType::Identifier:
@@ -759,7 +824,6 @@ namespace Nz::ShaderLang
case TokenType::Return:
statement = ParseReturnStatement();
Expect(Advance(), TokenType::Semicolon);
break;
default:
@@ -809,23 +873,12 @@ namespace Nz::ShaderLang
{
Expect(Advance(), TokenType::Let);
std::string variableName = ParseIdentifierAsName();
RegisterVariable(variableName);
ShaderAst::ExpressionType variableType = ShaderAst::NoType{};
if (Peek().type == TokenType::Colon)
{
Expect(Advance(), TokenType::Colon);
variableType = ParseType();
}
std::string variableName;
ShaderAst::ExpressionType variableType;
ShaderAst::ExpressionPtr expression;
if (IsNoType(variableType) || Peek().type == TokenType::Assign)
{
Expect(Advance(), TokenType::Assign);
expression = ParseExpression();
}
ParseVariableDeclaration(variableName, variableType, expression);
RegisterVariable(variableName);
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}

View File

@@ -584,6 +584,11 @@ namespace Nz
}, node.value);
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareConstStatement& /*node*/)
{
/* nothing to do */
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareExternalStatement& node)
{
assert(node.varIndex);

View File

@@ -86,6 +86,23 @@ namespace Nz
m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value()));
}
void Visit(ShaderAst::CallFunctionExpression& node) override
{
AstRecursiveVisitor::Visit(node);
assert(m_funcIndex);
auto& func = m_funcs[*m_funcIndex];
auto& funcCall = func.funcCalls.emplace_back();
funcCall.firstVarIndex = func.variables.size();
for (const auto& parameter : node.parameters)
{
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
}
}
void Visit(ShaderAst::ConditionalExpression& node) override
{
throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?");
@@ -126,23 +143,6 @@ namespace Nz
}
}
void Visit(ShaderAst::CallFunctionExpression& node) override
{
AstRecursiveVisitor::Visit(node);
assert(m_funcIndex);
auto& func = m_funcs[*m_funcIndex];
auto& funcCall = func.funcCalls.emplace_back();
funcCall.firstVarIndex = func.variables.size();
for (const auto& parameter : node.parameters)
{
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
}
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
std::optional<ShaderStageType> entryPointType;

View File

@@ -68,8 +68,11 @@ void CodeOutputWidget::Refresh()
{
shaderAst = Nz::ShaderAst::Sanitize(*shaderAst);
Nz::ShaderAst::AstOptimizer::Options optimOptions;
optimOptions.enabledOptions = enabledConditions;
Nz::ShaderAst::AstOptimizer optimiser;
shaderAst = optimiser.Optimise(*shaderAst, enabledConditions);
shaderAst = optimiser.Optimise(*shaderAst, optimOptions);
}
Nz::ShaderWriter::States states;