Shader: Handle type as expressions

This commit is contained in:
Jérôme Leclercq
2022-02-08 17:03:34 +01:00
parent 5ce8120a0c
commit 402e16bd2b
53 changed files with 1746 additions and 1141 deletions

View File

@@ -36,6 +36,20 @@ namespace Nz::ShaderAst
return PopStatement();
}
ExpressionValue<ExpressionType> AstCloner::CloneType(const ExpressionValue<ExpressionType>& exprType)
{
if (!exprType.HasValue())
return {};
if (exprType.IsExpression())
return CloneExpression(exprType.GetExpression());
else
{
assert(exprType.IsResultingValue());
return exprType.GetResultingValue();
}
}
StatementPtr AstCloner::Clone(BranchStatement& node)
{
auto clone = std::make_unique<BranchStatement>();
@@ -68,7 +82,7 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<DeclareConstStatement>();
clone->constIndex = node.constIndex;
clone->name = node.name;
clone->type = node.type;
clone->type = Clone(node.type);
clone->expression = CloneExpression(node.expression);
return clone;
@@ -86,7 +100,7 @@ namespace Nz::ShaderAst
{
auto& cloneVar = clone->externalVars.emplace_back();
cloneVar.name = var.name;
cloneVar.type = var.type;
cloneVar.type = Clone(var.type);
cloneVar.bindingIndex = Clone(var.bindingIndex);
cloneVar.bindingSet = Clone(var.bindingSet);
}
@@ -102,10 +116,17 @@ namespace Nz::ShaderAst
clone->entryStage = Clone(node.entryStage);
clone->funcIndex = node.funcIndex;
clone->name = node.name;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->returnType = Clone(node.returnType);
clone->varIndex = node.varIndex;
clone->parameters.reserve(node.parameters.size());
for (auto& parameter : node.parameters)
{
auto& cloneParam = clone->parameters.emplace_back();
cloneParam.name = parameter.name;
cloneParam.type = Clone(parameter.type);
}
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
clone->statements.push_back(CloneStatement(statement));
@@ -119,7 +140,7 @@ namespace Nz::ShaderAst
clone->defaultValue = CloneExpression(node.defaultValue);
clone->optIndex = node.optIndex;
clone->optName = node.optName;
clone->optType = node.optType;
clone->optType = Clone(node.optType);
return clone;
}
@@ -137,7 +158,7 @@ namespace Nz::ShaderAst
{
auto& cloneMember = clone->description.members.emplace_back();
cloneMember.name = member.name;
cloneMember.type = member.type;
cloneMember.type = Clone(member.type);
cloneMember.builtin = Clone(member.builtin);
cloneMember.cond = Clone(member.cond);
cloneMember.locationIndex = Clone(member.locationIndex);
@@ -151,7 +172,7 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<DeclareVariableStatement>();
clone->varIndex = node.varIndex;
clone->varName = node.varName;
clone->varType = node.varType;
clone->varType = Clone(node.varType);
clone->initialExpression = CloneExpression(node.initialExpression);
return clone;
@@ -217,6 +238,14 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr AstCloner::Clone(ScopedStatement& node)
{
auto clone = std::make_unique<ScopedStatement>();
clone->statement = CloneStatement(node.statement);
return clone;
}
StatementPtr AstCloner::Clone(WhileStatement& node)
{
auto clone = std::make_unique<WhileStatement>();
@@ -279,7 +308,7 @@ namespace Nz::ShaderAst
ExpressionPtr AstCloner::Clone(CallFunctionExpression& node)
{
auto clone = std::make_unique<CallFunctionExpression>();
clone->targetFunction = node.targetFunction;
clone->targetFunction = CloneExpression(node.targetFunction);
clone->parameters.reserve(node.parameters.size());
for (auto& parameter : node.parameters)
@@ -309,7 +338,7 @@ namespace Nz::ShaderAst
ExpressionPtr AstCloner::Clone(CastExpression& node)
{
auto clone = std::make_unique<CastExpression>();
clone->targetType = node.targetType;
clone->targetType = Clone(node.targetType);
std::size_t expressionCount = 0;
for (auto& expr : node.expressions)

View File

@@ -7,7 +7,7 @@
namespace Nz::ShaderAst
{
#define NAZARA_SHADERAST_EXPRESSION(Node) void ExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
#define NAZARA_SHADERAST_EXPRESSION(Node) void AstExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
{ \
throw std::runtime_error("unexpected " #Node " node"); \
}

View File

@@ -818,13 +818,13 @@ namespace Nz::ShaderAst
}
ExpressionPtr optimized;
if (IsPrimitiveType(node.targetType))
if (IsPrimitiveType(node.targetType.GetResultingValue()))
{
if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression)
{
const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expressions.front());
switch (std::get<PrimitiveType>(node.targetType))
switch (std::get<PrimitiveType>(node.targetType.GetResultingValue()))
{
case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(constantExpr); break;
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(constantExpr); break;
@@ -833,9 +833,9 @@ namespace Nz::ShaderAst
}
}
}
else if (IsVectorType(node.targetType))
else if (IsVectorType(node.targetType.GetResultingValue()))
{
const auto& vecType = std::get<VectorType>(node.targetType);
const auto& vecType = std::get<VectorType>(node.targetType.GetResultingValue());
// Decompose vector into values (cast(vec3, float) => cast(float, float, float, float))
std::vector<ConstantValue> constantValues;
@@ -916,7 +916,7 @@ namespace Nz::ShaderAst
if (optimized)
return optimized;
auto cast = ShaderBuilder::Cast(node.targetType, std::move(expressions));
auto cast = ShaderBuilder::Cast(node.targetType.GetResultingValue(), std::move(expressions));
cast->cachedExpressionType = node.cachedExpressionType;
return cast;
@@ -946,7 +946,7 @@ namespace Nz::ShaderAst
if (statements.empty())
{
// First condition is true, dismiss the branch
return AstCloner::Clone(*condStatement.statement);
return Unscope(AstCloner::Clone(*condStatement.statement));
}
else
{
@@ -967,7 +967,7 @@ namespace Nz::ShaderAst
{
// All conditions have been removed, replace by else statement or no-op
if (node.elseStatement)
return AstCloner::Clone(*node.elseStatement);
return Unscope(AstCloner::Clone(*node.elseStatement));
else
return ShaderBuilder::NoOp();
}
@@ -1243,4 +1243,15 @@ namespace Nz::ShaderAst
return optimized;
}
StatementPtr AstOptimizer::Unscope(StatementPtr node)
{
assert(node);
if (node->GetType() == NodeType::ScopedStatement)
return std::move(static_cast<ScopedStatement&>(*node).statement);
else
return node;
}
}

View File

@@ -202,6 +202,12 @@ namespace Nz::ShaderAst
node.returnExpr->Visit(*this);
}
void AstRecursiveVisitor::Visit(ScopedStatement& node)
{
if (node.statement)
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(WhileStatement& node)
{
if (node.condition)

View File

@@ -67,27 +67,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(CallFunctionExpression& node)
{
UInt32 typeIndex;
if (IsWriting())
typeIndex = UInt32(node.targetFunction.index());
Value(typeIndex);
// Waiting for template lambda in C++20
auto SerializeValue = [&](auto dummyType)
{
using T = std::decay_t<decltype(dummyType)>;
auto& value = (IsWriting()) ? std::get<T>(node.targetFunction) : node.targetFunction.emplace<T>();
Value(value);
};
static_assert(std::variant_size_v<decltype(node.targetFunction)> == 2);
switch (typeIndex)
{
case 0: SerializeValue(std::string()); break;
case 1: SerializeValue(std::size_t()); break;
}
Node(node.targetFunction);
Container(node.parameters);
for (auto& param : node.parameters)
@@ -106,7 +86,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(CastExpression& node)
{
Type(node.targetType);
ExprValue(node.targetType);
for (auto& expr : node.expressions)
Node(expr);
}
@@ -215,15 +195,15 @@ namespace Nz::ShaderAst
{
OptVal(node.varIndex);
Attribute(node.bindingSet);
ExprValue(node.bindingSet);
Container(node.externalVars);
for (auto& extVar : node.externalVars)
{
Value(extVar.name);
Type(extVar.type);
Attribute(extVar.bindingIndex);
Attribute(extVar.bindingSet);
ExprValue(extVar.type);
ExprValue(extVar.bindingIndex);
ExprValue(extVar.bindingSet);
}
}
@@ -231,17 +211,17 @@ namespace Nz::ShaderAst
{
OptVal(node.constIndex);
Value(node.name);
Type(node.type);
ExprValue(node.type);
Node(node.expression);
}
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
{
Value(node.name);
Type(node.returnType);
Attribute(node.depthWrite);
Attribute(node.earlyFragmentTests);
Attribute(node.entryStage);
ExprValue(node.returnType);
ExprValue(node.depthWrite);
ExprValue(node.earlyFragmentTests);
ExprValue(node.entryStage);
OptVal(node.funcIndex);
OptVal(node.varIndex);
@@ -249,7 +229,7 @@ namespace Nz::ShaderAst
for (auto& parameter : node.parameters)
{
Value(parameter.name);
Type(parameter.type);
ExprValue(parameter.type);
}
Container(node.statements);
@@ -261,7 +241,7 @@ namespace Nz::ShaderAst
{
OptVal(node.optIndex);
Value(node.optName);
Type(node.optType);
ExprValue(node.optType);
Node(node.defaultValue);
}
@@ -270,16 +250,16 @@ namespace Nz::ShaderAst
OptVal(node.structIndex);
Value(node.description.name);
Attribute(node.description.layout);
ExprValue(node.description.layout);
Container(node.description.members);
for (auto& member : node.description.members)
{
Value(member.name);
Type(member.type);
Attribute(member.builtin);
Attribute(member.cond);
Attribute(member.locationIndex);
ExprValue(member.type);
ExprValue(member.builtin);
ExprValue(member.cond);
ExprValue(member.locationIndex);
}
}
@@ -287,7 +267,7 @@ namespace Nz::ShaderAst
{
OptVal(node.varIndex);
Value(node.varName);
Type(node.varType);
ExprValue(node.varType);
Node(node.initialExpression);
}
@@ -303,7 +283,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(ForStatement& node)
{
Attribute(node.unroll);
ExprValue(node.unroll);
Value(node.varName);
Node(node.fromExpr);
Node(node.toExpr);
@@ -313,7 +293,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(ForEachStatement& node)
{
Attribute(node.unroll);
ExprValue(node.unroll);
Value(node.varName);
Node(node.expression);
Node(node.statement);
@@ -336,9 +316,14 @@ namespace Nz::ShaderAst
Node(node.returnExpr);
}
void AstSerializerBase::Serialize(ScopedStatement& node)
{
Node(node.statement);
}
void AstSerializerBase::Serialize(WhileStatement& node)
{
Attribute(node.unroll);
ExprValue(node.unroll);
Node(node.condition);
Node(node.body);
}
@@ -392,7 +377,7 @@ namespace Nz::ShaderAst
else if constexpr (std::is_same_v<T, PrimitiveType>)
{
m_stream << UInt8(1);
m_stream << UInt32(arg);
Enum(arg);
}
else if constexpr (std::is_same_v<T, IdentifierType>)
{
@@ -402,38 +387,59 @@ namespace Nz::ShaderAst
else if constexpr (std::is_same_v<T, MatrixType>)
{
m_stream << UInt8(3);
m_stream << UInt32(arg.columnCount);
m_stream << UInt32(arg.rowCount);
m_stream << UInt32(arg.type);
SizeT(arg.columnCount);
SizeT(arg.rowCount);
Enum(arg.type);
}
else if constexpr (std::is_same_v<T, SamplerType>)
{
m_stream << UInt8(4);
m_stream << UInt32(arg.dim);
m_stream << UInt32(arg.sampledType);
Enum(arg.dim);
Enum(arg.sampledType);
}
else if constexpr (std::is_same_v<T, StructType>)
{
m_stream << UInt8(5);
m_stream << UInt32(arg.structIndex);
SizeT(arg.structIndex);
}
else if constexpr (std::is_same_v<T, UniformType>)
{
m_stream << UInt8(6);
m_stream << std::get<IdentifierType>(arg.containedType).name;
SizeT(arg.containedType.structIndex);
}
else if constexpr (std::is_same_v<T, VectorType>)
{
m_stream << UInt8(7);
m_stream << UInt32(arg.componentCount);
m_stream << UInt32(arg.type);
SizeT(arg.componentCount);
Enum(arg.type);
}
else if constexpr (std::is_same_v<T, ArrayType>)
{
m_stream << UInt8(8);
Attribute(arg.length);
Value(arg.length);
Type(arg.containedType->type);
}
else if constexpr (std::is_same_v<T, ShaderAst::Type>)
{
m_stream << UInt8(9);
SizeT(arg.typeIndex);
}
else if constexpr (std::is_same_v<T, ShaderAst::FunctionType>)
{
m_stream << UInt8(10);
SizeT(arg.funcIndex);
}
else if constexpr (std::is_same_v<T, ShaderAst::IntrinsicFunctionType>)
{
m_stream << UInt8(11);
Enum(arg.intrinsic);
}
else if constexpr (std::is_same_v<T, ShaderAst::MethodType>)
{
m_stream << UInt8(12);
Type(arg.objectType->type);
SizeT(arg.methodIndex);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
@@ -618,10 +624,10 @@ namespace Nz::ShaderAst
case 3: //< MatrixType
{
UInt32 columnCount, rowCount;
std::size_t columnCount, rowCount;
PrimitiveType primitiveType;
Value(columnCount);
Value(rowCount);
SizeT(columnCount);
SizeT(rowCount);
Enum(primitiveType);
type = MatrixType {
@@ -659,12 +665,12 @@ namespace Nz::ShaderAst
case 6: //< UniformType
{
std::string containedType;
Value(containedType);
std::size_t structIndex;
SizeT(structIndex);
type = UniformType {
IdentifierType {
containedType
StructType {
structIndex
}
};
break;
@@ -672,9 +678,9 @@ namespace Nz::ShaderAst
case 7: //< VectorType
{
UInt32 componentCount;
std::size_t componentCount;
PrimitiveType componentType;
Value(componentCount);
SizeT(componentCount);
Enum(componentType);
type = VectorType{
@@ -686,13 +692,13 @@ namespace Nz::ShaderAst
case 8: //< ArrayType
{
AttributeValue<UInt32> length;
UInt32 length;
ExpressionType containedType;
Attribute(length);
Value(length);
Type(containedType);
ArrayType arrayType;
arrayType.length = std::move(length);
arrayType.length = length;
arrayType.containedType = std::make_unique<ContainedType>();
arrayType.containedType->type = std::move(containedType);
@@ -700,6 +706,52 @@ namespace Nz::ShaderAst
break;
}
case 9: //< Type
{
std::size_t containedTypeIndex;
SizeT(containedTypeIndex);
type = ShaderAst::Type{
containedTypeIndex
};
}
case 10: //< FunctionType
{
std::size_t funcIndex;
SizeT(funcIndex);
type = FunctionType {
funcIndex
};
}
case 11: //< IntrinsicFunctionType
{
IntrinsicType intrinsicType;
Enum(intrinsicType);
type = IntrinsicFunctionType {
intrinsicType
};
}
case 12: //< MethodType
{
ExpressionType objectType;
Type(objectType);
std::size_t methodIndex;
SizeT(methodIndex);
MethodType methodType;
methodType.objectType = std::make_unique<ContainedType>();
methodType.objectType->type = std::move(objectType);
methodType.methodIndex = methodIndex;
type = std::move(methodType);
}
default:
break;
}

View File

@@ -7,7 +7,7 @@
namespace Nz::ShaderAst
{
#define NAZARA_SHADERAST_STATEMENT(Node) void StatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
#define NAZARA_SHADERAST_STATEMENT(Node) void AstStatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
{ \
throw std::runtime_error("unexpected " #Node " node"); \
}

View File

@@ -9,11 +9,11 @@
namespace Nz::ShaderAst
{
ArrayType::ArrayType(const ArrayType& array)
ArrayType::ArrayType(const ArrayType& array) :
length(array.length)
{
assert(array.containedType);
containedType = std::make_unique<ContainedType>(*array.containedType);
length = Clone(array.length);
}
ArrayType& ArrayType::operator=(const ArrayType& array)
@@ -21,7 +21,7 @@ namespace Nz::ShaderAst
assert(array.containedType);
containedType = std::make_unique<ContainedType>(*array.containedType);
length = Clone(array.length);
length = array.length;
return *this;
}
@@ -34,9 +34,34 @@ namespace Nz::ShaderAst
if (containedType->type != rhs.containedType->type)
return false;
if (!Compare(length, rhs.length))
if (length != rhs.length)
return false;
return true;
}
MethodType::MethodType(const MethodType& methodType) :
methodIndex(methodType.methodIndex)
{
assert(methodType.objectType);
objectType = std::make_unique<ContainedType>(*methodType.objectType);
}
MethodType& MethodType::operator=(const MethodType& methodType)
{
assert(methodType.objectType);
methodIndex = methodType.methodIndex;
objectType = std::make_unique<ContainedType>(*methodType.objectType);
return *this;
}
bool MethodType::operator==(const MethodType& rhs) const
{
assert(objectType);
assert(rhs.objectType);
return objectType->type == rhs.objectType->type && methodIndex == rhs.methodIndex;
}
}

File diff suppressed because it is too large Load Diff