Shader: Handle type as expressions

This commit is contained in:
Jérôme Leclercq
2022-02-08 17:03:34 +01:00
parent 5ce8120a0c
commit 402e16bd2b
53 changed files with 1746 additions and 1141 deletions

View File

@@ -36,6 +36,20 @@ namespace Nz::ShaderAst
return PopStatement();
}
ExpressionValue<ExpressionType> AstCloner::CloneType(const ExpressionValue<ExpressionType>& exprType)
{
if (!exprType.HasValue())
return {};
if (exprType.IsExpression())
return CloneExpression(exprType.GetExpression());
else
{
assert(exprType.IsResultingValue());
return exprType.GetResultingValue();
}
}
StatementPtr AstCloner::Clone(BranchStatement& node)
{
auto clone = std::make_unique<BranchStatement>();
@@ -68,7 +82,7 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<DeclareConstStatement>();
clone->constIndex = node.constIndex;
clone->name = node.name;
clone->type = node.type;
clone->type = Clone(node.type);
clone->expression = CloneExpression(node.expression);
return clone;
@@ -86,7 +100,7 @@ namespace Nz::ShaderAst
{
auto& cloneVar = clone->externalVars.emplace_back();
cloneVar.name = var.name;
cloneVar.type = var.type;
cloneVar.type = Clone(var.type);
cloneVar.bindingIndex = Clone(var.bindingIndex);
cloneVar.bindingSet = Clone(var.bindingSet);
}
@@ -102,10 +116,17 @@ namespace Nz::ShaderAst
clone->entryStage = Clone(node.entryStage);
clone->funcIndex = node.funcIndex;
clone->name = node.name;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->returnType = Clone(node.returnType);
clone->varIndex = node.varIndex;
clone->parameters.reserve(node.parameters.size());
for (auto& parameter : node.parameters)
{
auto& cloneParam = clone->parameters.emplace_back();
cloneParam.name = parameter.name;
cloneParam.type = Clone(parameter.type);
}
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
clone->statements.push_back(CloneStatement(statement));
@@ -119,7 +140,7 @@ namespace Nz::ShaderAst
clone->defaultValue = CloneExpression(node.defaultValue);
clone->optIndex = node.optIndex;
clone->optName = node.optName;
clone->optType = node.optType;
clone->optType = Clone(node.optType);
return clone;
}
@@ -137,7 +158,7 @@ namespace Nz::ShaderAst
{
auto& cloneMember = clone->description.members.emplace_back();
cloneMember.name = member.name;
cloneMember.type = member.type;
cloneMember.type = Clone(member.type);
cloneMember.builtin = Clone(member.builtin);
cloneMember.cond = Clone(member.cond);
cloneMember.locationIndex = Clone(member.locationIndex);
@@ -151,7 +172,7 @@ namespace Nz::ShaderAst
auto clone = std::make_unique<DeclareVariableStatement>();
clone->varIndex = node.varIndex;
clone->varName = node.varName;
clone->varType = node.varType;
clone->varType = Clone(node.varType);
clone->initialExpression = CloneExpression(node.initialExpression);
return clone;
@@ -217,6 +238,14 @@ namespace Nz::ShaderAst
return clone;
}
StatementPtr AstCloner::Clone(ScopedStatement& node)
{
auto clone = std::make_unique<ScopedStatement>();
clone->statement = CloneStatement(node.statement);
return clone;
}
StatementPtr AstCloner::Clone(WhileStatement& node)
{
auto clone = std::make_unique<WhileStatement>();
@@ -279,7 +308,7 @@ namespace Nz::ShaderAst
ExpressionPtr AstCloner::Clone(CallFunctionExpression& node)
{
auto clone = std::make_unique<CallFunctionExpression>();
clone->targetFunction = node.targetFunction;
clone->targetFunction = CloneExpression(node.targetFunction);
clone->parameters.reserve(node.parameters.size());
for (auto& parameter : node.parameters)
@@ -309,7 +338,7 @@ namespace Nz::ShaderAst
ExpressionPtr AstCloner::Clone(CastExpression& node)
{
auto clone = std::make_unique<CastExpression>();
clone->targetType = node.targetType;
clone->targetType = Clone(node.targetType);
std::size_t expressionCount = 0;
for (auto& expr : node.expressions)

View File

@@ -7,7 +7,7 @@
namespace Nz::ShaderAst
{
#define NAZARA_SHADERAST_EXPRESSION(Node) void ExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
#define NAZARA_SHADERAST_EXPRESSION(Node) void AstExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
{ \
throw std::runtime_error("unexpected " #Node " node"); \
}

View File

@@ -818,13 +818,13 @@ namespace Nz::ShaderAst
}
ExpressionPtr optimized;
if (IsPrimitiveType(node.targetType))
if (IsPrimitiveType(node.targetType.GetResultingValue()))
{
if (expressionCount == 1 && expressions.front()->GetType() == NodeType::ConstantValueExpression)
{
const ConstantValueExpression& constantExpr = static_cast<const ConstantValueExpression&>(*expressions.front());
switch (std::get<PrimitiveType>(node.targetType))
switch (std::get<PrimitiveType>(node.targetType.GetResultingValue()))
{
case PrimitiveType::Boolean: optimized = PropagateSingleValueCast<bool>(constantExpr); break;
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(constantExpr); break;
@@ -833,9 +833,9 @@ namespace Nz::ShaderAst
}
}
}
else if (IsVectorType(node.targetType))
else if (IsVectorType(node.targetType.GetResultingValue()))
{
const auto& vecType = std::get<VectorType>(node.targetType);
const auto& vecType = std::get<VectorType>(node.targetType.GetResultingValue());
// Decompose vector into values (cast(vec3, float) => cast(float, float, float, float))
std::vector<ConstantValue> constantValues;
@@ -916,7 +916,7 @@ namespace Nz::ShaderAst
if (optimized)
return optimized;
auto cast = ShaderBuilder::Cast(node.targetType, std::move(expressions));
auto cast = ShaderBuilder::Cast(node.targetType.GetResultingValue(), std::move(expressions));
cast->cachedExpressionType = node.cachedExpressionType;
return cast;
@@ -946,7 +946,7 @@ namespace Nz::ShaderAst
if (statements.empty())
{
// First condition is true, dismiss the branch
return AstCloner::Clone(*condStatement.statement);
return Unscope(AstCloner::Clone(*condStatement.statement));
}
else
{
@@ -967,7 +967,7 @@ namespace Nz::ShaderAst
{
// All conditions have been removed, replace by else statement or no-op
if (node.elseStatement)
return AstCloner::Clone(*node.elseStatement);
return Unscope(AstCloner::Clone(*node.elseStatement));
else
return ShaderBuilder::NoOp();
}
@@ -1243,4 +1243,15 @@ namespace Nz::ShaderAst
return optimized;
}
StatementPtr AstOptimizer::Unscope(StatementPtr node)
{
assert(node);
if (node->GetType() == NodeType::ScopedStatement)
return std::move(static_cast<ScopedStatement&>(*node).statement);
else
return node;
}
}

View File

@@ -202,6 +202,12 @@ namespace Nz::ShaderAst
node.returnExpr->Visit(*this);
}
void AstRecursiveVisitor::Visit(ScopedStatement& node)
{
if (node.statement)
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(WhileStatement& node)
{
if (node.condition)

View File

@@ -67,27 +67,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(CallFunctionExpression& node)
{
UInt32 typeIndex;
if (IsWriting())
typeIndex = UInt32(node.targetFunction.index());
Value(typeIndex);
// Waiting for template lambda in C++20
auto SerializeValue = [&](auto dummyType)
{
using T = std::decay_t<decltype(dummyType)>;
auto& value = (IsWriting()) ? std::get<T>(node.targetFunction) : node.targetFunction.emplace<T>();
Value(value);
};
static_assert(std::variant_size_v<decltype(node.targetFunction)> == 2);
switch (typeIndex)
{
case 0: SerializeValue(std::string()); break;
case 1: SerializeValue(std::size_t()); break;
}
Node(node.targetFunction);
Container(node.parameters);
for (auto& param : node.parameters)
@@ -106,7 +86,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(CastExpression& node)
{
Type(node.targetType);
ExprValue(node.targetType);
for (auto& expr : node.expressions)
Node(expr);
}
@@ -215,15 +195,15 @@ namespace Nz::ShaderAst
{
OptVal(node.varIndex);
Attribute(node.bindingSet);
ExprValue(node.bindingSet);
Container(node.externalVars);
for (auto& extVar : node.externalVars)
{
Value(extVar.name);
Type(extVar.type);
Attribute(extVar.bindingIndex);
Attribute(extVar.bindingSet);
ExprValue(extVar.type);
ExprValue(extVar.bindingIndex);
ExprValue(extVar.bindingSet);
}
}
@@ -231,17 +211,17 @@ namespace Nz::ShaderAst
{
OptVal(node.constIndex);
Value(node.name);
Type(node.type);
ExprValue(node.type);
Node(node.expression);
}
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
{
Value(node.name);
Type(node.returnType);
Attribute(node.depthWrite);
Attribute(node.earlyFragmentTests);
Attribute(node.entryStage);
ExprValue(node.returnType);
ExprValue(node.depthWrite);
ExprValue(node.earlyFragmentTests);
ExprValue(node.entryStage);
OptVal(node.funcIndex);
OptVal(node.varIndex);
@@ -249,7 +229,7 @@ namespace Nz::ShaderAst
for (auto& parameter : node.parameters)
{
Value(parameter.name);
Type(parameter.type);
ExprValue(parameter.type);
}
Container(node.statements);
@@ -261,7 +241,7 @@ namespace Nz::ShaderAst
{
OptVal(node.optIndex);
Value(node.optName);
Type(node.optType);
ExprValue(node.optType);
Node(node.defaultValue);
}
@@ -270,16 +250,16 @@ namespace Nz::ShaderAst
OptVal(node.structIndex);
Value(node.description.name);
Attribute(node.description.layout);
ExprValue(node.description.layout);
Container(node.description.members);
for (auto& member : node.description.members)
{
Value(member.name);
Type(member.type);
Attribute(member.builtin);
Attribute(member.cond);
Attribute(member.locationIndex);
ExprValue(member.type);
ExprValue(member.builtin);
ExprValue(member.cond);
ExprValue(member.locationIndex);
}
}
@@ -287,7 +267,7 @@ namespace Nz::ShaderAst
{
OptVal(node.varIndex);
Value(node.varName);
Type(node.varType);
ExprValue(node.varType);
Node(node.initialExpression);
}
@@ -303,7 +283,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(ForStatement& node)
{
Attribute(node.unroll);
ExprValue(node.unroll);
Value(node.varName);
Node(node.fromExpr);
Node(node.toExpr);
@@ -313,7 +293,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(ForEachStatement& node)
{
Attribute(node.unroll);
ExprValue(node.unroll);
Value(node.varName);
Node(node.expression);
Node(node.statement);
@@ -336,9 +316,14 @@ namespace Nz::ShaderAst
Node(node.returnExpr);
}
void AstSerializerBase::Serialize(ScopedStatement& node)
{
Node(node.statement);
}
void AstSerializerBase::Serialize(WhileStatement& node)
{
Attribute(node.unroll);
ExprValue(node.unroll);
Node(node.condition);
Node(node.body);
}
@@ -392,7 +377,7 @@ namespace Nz::ShaderAst
else if constexpr (std::is_same_v<T, PrimitiveType>)
{
m_stream << UInt8(1);
m_stream << UInt32(arg);
Enum(arg);
}
else if constexpr (std::is_same_v<T, IdentifierType>)
{
@@ -402,38 +387,59 @@ namespace Nz::ShaderAst
else if constexpr (std::is_same_v<T, MatrixType>)
{
m_stream << UInt8(3);
m_stream << UInt32(arg.columnCount);
m_stream << UInt32(arg.rowCount);
m_stream << UInt32(arg.type);
SizeT(arg.columnCount);
SizeT(arg.rowCount);
Enum(arg.type);
}
else if constexpr (std::is_same_v<T, SamplerType>)
{
m_stream << UInt8(4);
m_stream << UInt32(arg.dim);
m_stream << UInt32(arg.sampledType);
Enum(arg.dim);
Enum(arg.sampledType);
}
else if constexpr (std::is_same_v<T, StructType>)
{
m_stream << UInt8(5);
m_stream << UInt32(arg.structIndex);
SizeT(arg.structIndex);
}
else if constexpr (std::is_same_v<T, UniformType>)
{
m_stream << UInt8(6);
m_stream << std::get<IdentifierType>(arg.containedType).name;
SizeT(arg.containedType.structIndex);
}
else if constexpr (std::is_same_v<T, VectorType>)
{
m_stream << UInt8(7);
m_stream << UInt32(arg.componentCount);
m_stream << UInt32(arg.type);
SizeT(arg.componentCount);
Enum(arg.type);
}
else if constexpr (std::is_same_v<T, ArrayType>)
{
m_stream << UInt8(8);
Attribute(arg.length);
Value(arg.length);
Type(arg.containedType->type);
}
else if constexpr (std::is_same_v<T, ShaderAst::Type>)
{
m_stream << UInt8(9);
SizeT(arg.typeIndex);
}
else if constexpr (std::is_same_v<T, ShaderAst::FunctionType>)
{
m_stream << UInt8(10);
SizeT(arg.funcIndex);
}
else if constexpr (std::is_same_v<T, ShaderAst::IntrinsicFunctionType>)
{
m_stream << UInt8(11);
Enum(arg.intrinsic);
}
else if constexpr (std::is_same_v<T, ShaderAst::MethodType>)
{
m_stream << UInt8(12);
Type(arg.objectType->type);
SizeT(arg.methodIndex);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
@@ -618,10 +624,10 @@ namespace Nz::ShaderAst
case 3: //< MatrixType
{
UInt32 columnCount, rowCount;
std::size_t columnCount, rowCount;
PrimitiveType primitiveType;
Value(columnCount);
Value(rowCount);
SizeT(columnCount);
SizeT(rowCount);
Enum(primitiveType);
type = MatrixType {
@@ -659,12 +665,12 @@ namespace Nz::ShaderAst
case 6: //< UniformType
{
std::string containedType;
Value(containedType);
std::size_t structIndex;
SizeT(structIndex);
type = UniformType {
IdentifierType {
containedType
StructType {
structIndex
}
};
break;
@@ -672,9 +678,9 @@ namespace Nz::ShaderAst
case 7: //< VectorType
{
UInt32 componentCount;
std::size_t componentCount;
PrimitiveType componentType;
Value(componentCount);
SizeT(componentCount);
Enum(componentType);
type = VectorType{
@@ -686,13 +692,13 @@ namespace Nz::ShaderAst
case 8: //< ArrayType
{
AttributeValue<UInt32> length;
UInt32 length;
ExpressionType containedType;
Attribute(length);
Value(length);
Type(containedType);
ArrayType arrayType;
arrayType.length = std::move(length);
arrayType.length = length;
arrayType.containedType = std::make_unique<ContainedType>();
arrayType.containedType->type = std::move(containedType);
@@ -700,6 +706,52 @@ namespace Nz::ShaderAst
break;
}
case 9: //< Type
{
std::size_t containedTypeIndex;
SizeT(containedTypeIndex);
type = ShaderAst::Type{
containedTypeIndex
};
}
case 10: //< FunctionType
{
std::size_t funcIndex;
SizeT(funcIndex);
type = FunctionType {
funcIndex
};
}
case 11: //< IntrinsicFunctionType
{
IntrinsicType intrinsicType;
Enum(intrinsicType);
type = IntrinsicFunctionType {
intrinsicType
};
}
case 12: //< MethodType
{
ExpressionType objectType;
Type(objectType);
std::size_t methodIndex;
SizeT(methodIndex);
MethodType methodType;
methodType.objectType = std::make_unique<ContainedType>();
methodType.objectType->type = std::move(objectType);
methodType.methodIndex = methodIndex;
type = std::move(methodType);
}
default:
break;
}

View File

@@ -7,7 +7,7 @@
namespace Nz::ShaderAst
{
#define NAZARA_SHADERAST_STATEMENT(Node) void StatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
#define NAZARA_SHADERAST_STATEMENT(Node) void AstStatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
{ \
throw std::runtime_error("unexpected " #Node " node"); \
}

View File

@@ -9,11 +9,11 @@
namespace Nz::ShaderAst
{
ArrayType::ArrayType(const ArrayType& array)
ArrayType::ArrayType(const ArrayType& array) :
length(array.length)
{
assert(array.containedType);
containedType = std::make_unique<ContainedType>(*array.containedType);
length = Clone(array.length);
}
ArrayType& ArrayType::operator=(const ArrayType& array)
@@ -21,7 +21,7 @@ namespace Nz::ShaderAst
assert(array.containedType);
containedType = std::make_unique<ContainedType>(*array.containedType);
length = Clone(array.length);
length = array.length;
return *this;
}
@@ -34,9 +34,34 @@ namespace Nz::ShaderAst
if (containedType->type != rhs.containedType->type)
return false;
if (!Compare(length, rhs.length))
if (length != rhs.length)
return false;
return true;
}
MethodType::MethodType(const MethodType& methodType) :
methodIndex(methodType.methodIndex)
{
assert(methodType.objectType);
objectType = std::make_unique<ContainedType>(*methodType.objectType);
}
MethodType& MethodType::operator=(const MethodType& methodType)
{
assert(methodType.objectType);
methodIndex = methodType.methodIndex;
objectType = std::make_unique<ContainedType>(*methodType.objectType);
return *this;
}
bool MethodType::operator==(const MethodType& rhs) const
{
assert(objectType);
assert(rhs.objectType);
return objectType->type == rhs.objectType->type && methodIndex == rhs.methodIndex;
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -43,7 +43,7 @@ namespace Nz
AstRecursiveVisitor::Visit(node);
assert(currentFunction);
currentFunction->calledFunctions.UnboundedSet(std::get<std::size_t>(node.targetFunction));
currentFunction->calledFunctions.UnboundedSet(std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex);
}
void Visit(ShaderAst::ConditionalExpression& /*node*/) override
@@ -227,6 +227,24 @@ namespace Nz
throw std::runtime_error("unexpected ArrayType");
}
void GlslWriter::Append(ShaderAst::BuiltinEntry builtin)
{
switch (builtin)
{
case ShaderAst::BuiltinEntry::FragCoord:
Append("gl_FragCoord");
break;
case ShaderAst::BuiltinEntry::FragDepth:
Append("gl_FragDepth");
break;
case ShaderAst::BuiltinEntry::VertexPosition:
Append("gl_Position");
break;
}
}
void GlslWriter::Append(const ShaderAst::ExpressionType& type)
{
std::visit([&](auto&& arg)
@@ -235,22 +253,14 @@ namespace Nz
}, type);
}
void GlslWriter::Append(ShaderAst::BuiltinEntry builtin)
void GlslWriter::Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type)
{
switch (builtin)
{
case ShaderAst::BuiltinEntry::FragCoord:
Append("gl_FragCoord");
break;
Append(type.GetResultingValue());
}
case ShaderAst::BuiltinEntry::FragDepth:
Append("gl_FragDepth");
break;
case ShaderAst::BuiltinEntry::VertexPosition:
Append("gl_Position");
break;
}
void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
{
throw std::runtime_error("unexpected function type");
}
void GlslWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
@@ -258,6 +268,11 @@ namespace Nz
throw std::runtime_error("unexpected identifier type");
}
void GlslWriter::Append(const ShaderAst::IntrinsicFunctionType& /*intrinsicFunctionType*/)
{
throw std::runtime_error("unexpected intrinsic function type");
}
void GlslWriter::Append(const ShaderAst::MatrixType& matrixType)
{
if (matrixType.columnCount == matrixType.rowCount)
@@ -274,6 +289,11 @@ namespace Nz
}
}
void GlslWriter::Append(const ShaderAst::MethodType& methodType)
{
throw std::runtime_error("unexpected method type");
}
void GlslWriter::Append(ShaderAst::PrimitiveType type)
{
switch (type)
@@ -316,6 +336,11 @@ namespace Nz
Append(structDesc->name);
}
void GlslWriter::Append(const ShaderAst::Type& /*type*/)
{
throw std::runtime_error("unexpected Type");
}
void GlslWriter::Append(const ShaderAst::UniformType& /*uniformType*/)
{
throw std::runtime_error("unexpected UniformType");
@@ -386,7 +411,7 @@ namespace Nz
first = false;
AppendVariableDeclaration(parameter.type, parameter.name);
AppendVariableDeclaration(parameter.type.GetResultingValue(), parameter.name);
}
AppendLine((forward) ? ");" : ")");
}
@@ -506,13 +531,13 @@ namespace Nz
{
if (ShaderAst::IsArrayType(varType))
{
std::vector<const ShaderAst::AttributeValue<UInt32>*> lengths;
std::vector<UInt32> lengths;
const ShaderAst::ExpressionType* exprType = &varType;
while (ShaderAst::IsArrayType(*exprType))
{
const auto& arrayType = std::get<ShaderAst::ArrayType>(*exprType);
lengths.push_back(&arrayType.length);
lengths.push_back(arrayType.length);
exprType = &arrayType.containedType->type;
}
@@ -520,17 +545,8 @@ namespace Nz
assert(!ShaderAst::IsArrayType(*exprType));
Append(*exprType, " ", varName);
for (const auto* lengthAttribute : lengths)
{
Append("[");
if (lengthAttribute->IsResultingValue())
Append(lengthAttribute->GetResultingValue());
else
lengthAttribute->GetExpression()->Visit(*this);
Append("]");
}
for (UInt32 lengthAttribute : lengths)
Append("[", lengthAttribute, "]");
}
else
Append(varType, " ", varName);
@@ -582,8 +598,8 @@ namespace Nz
const std::string& varName = parameter.name;
RegisterVariable(*node.varIndex, varName);
assert(IsStructType(parameter.type));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
assert(IsStructType(parameter.type.GetResultingValue()));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type.GetResultingValue()).structIndex;
const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex);
AppendLine(structDesc->name, " ", varName, ";");
@@ -631,7 +647,7 @@ namespace Nz
Append("layout(location = ");
Append(member.locationIndex.GetResultingValue());
Append(") ", keyword, " ");
AppendVariableDeclaration(member.type, targetPrefix + member.name);
AppendVariableDeclaration(member.type.GetResultingValue(), targetPrefix + member.name);
AppendLine(";");
fields.push_back({
@@ -651,9 +667,9 @@ namespace Nz
{
assert(node.parameters.size() == 1);
auto& parameter = node.parameters.front();
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type.GetResultingValue()));
std::size_t inputStructIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
std::size_t inputStructIndex = std::get<ShaderAst::StructType>(parameter.type.GetResultingValue()).structIndex;
inputStruct = Retrieve(m_currentState->structs, inputStructIndex);
AppendCommentSection("Inputs");
@@ -666,10 +682,10 @@ namespace Nz
AppendLine();
}
if (!IsNoType(node.returnType))
if (node.returnType.HasValue())
{
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
std::size_t outputStructIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType.GetResultingValue()));
std::size_t outputStructIndex = std::get<ShaderAst::StructType>(node.returnType.GetResultingValue()).structIndex;
const ShaderAst::StructDescription* outputStruct = Retrieve(m_currentState->structs, outputStructIndex);
@@ -690,6 +706,18 @@ namespace Nz
m_currentState->variableNames.emplace(varIndex, std::move(varName));
}
void GlslWriter::ScopeVisit(ShaderAst::Statement& node)
{
if (node.GetType() != ShaderAst::NodeType::ScopedStatement)
{
EnterScope();
node.Visit(*this);
LeaveScope(true);
}
else
node.Visit(*this);
}
void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired)
{
bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue);
@@ -722,12 +750,10 @@ namespace Nz
assert(!IsStructType(exprType));
// Array access
for (ShaderAst::ExpressionPtr& expr : node.indices)
{
Append("[");
Visit(expr);
Append("]");
}
assert(node.indices.size() == 1);
Append("[");
Visit(node.indices.front());
Append("]");
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
@@ -775,8 +801,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::CallFunctionExpression& node)
{
assert(std::holds_alternative<std::size_t>(node.targetFunction));
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, std::get<std::size_t>(node.targetFunction)).name;
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex;
const std::string& targetName = Retrieve(m_currentState->previsitor.functions, functionIndex).name;
Append(targetName, "(");
for (std::size_t i = 0; i < node.parameters.size(); ++i)
@@ -946,9 +972,7 @@ namespace Nz
statement.condition->Visit(*this);
AppendLine(")");
EnterScope();
statement.statement->Visit(*this);
LeaveScope();
ScopeVisit(*statement.statement);
first = false;
}
@@ -957,9 +981,7 @@ namespace Nz
{
AppendLine("else");
EnterScope();
node.elseStatement->Visit(*this);
LeaveScope();
ScopeVisit(*node.elseStatement);
}
}
@@ -976,13 +998,10 @@ namespace Nz
for (const auto& externalVar : node.externalVars)
{
bool isStd140 = false;
if (IsUniformType(externalVar.type))
if (IsUniformType(externalVar.type.GetResultingValue()))
{
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type);
assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType));
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
ShaderAst::StructDescription* structInfo = Retrieve(m_currentState->structs, structIndex);
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type.GetResultingValue());
ShaderAst::StructDescription* structInfo = Retrieve(m_currentState->structs, uniform.containedType.structIndex);
if (structInfo->layout.HasValue())
isStd140 = structInfo->layout.GetResultingValue() == StructLayout::Std140;
}
@@ -1018,18 +1037,15 @@ namespace Nz
Append("uniform ");
if (IsUniformType(externalVar.type))
if (IsUniformType(externalVar.type.GetResultingValue()))
{
Append("_NzBinding_");
AppendLine(externalVar.name);
EnterScope();
{
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type);
assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType));
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structDesc = Retrieve(m_currentState->structs, structIndex);
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type.GetResultingValue());
auto& structDesc = Retrieve(m_currentState->structs, uniform.containedType.structIndex);
bool first = true;
for (const auto& member : structDesc->members)
@@ -1042,7 +1058,7 @@ namespace Nz
first = false;
AppendVariableDeclaration(member.type, member.name);
AppendVariableDeclaration(member.type.GetResultingValue(), member.name);
Append(";");
}
}
@@ -1052,11 +1068,11 @@ namespace Nz
Append(externalVar.name);
}
else
AppendVariableDeclaration(externalVar.type, externalVar.name);
AppendVariableDeclaration(externalVar.type.GetResultingValue(), externalVar.name);
AppendLine(";");
if (IsUniformType(externalVar.type))
if (IsUniformType(externalVar.type.GetResultingValue()))
AppendLine();
RegisterVariable(varIndex++, externalVar.name);
@@ -1138,7 +1154,7 @@ namespace Nz
first = false;
AppendVariableDeclaration(member.type, member.name);
AppendVariableDeclaration(member.type.GetResultingValue(), member.name);
Append(";");
}
}
@@ -1151,7 +1167,7 @@ namespace Nz
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
AppendVariableDeclaration(node.varType, node.varName);
AppendVariableDeclaration(node.varType.GetResultingValue(), node.varName);
if (node.initialExpression)
{
Append(" = ");
@@ -1239,15 +1255,20 @@ namespace Nz
}
}
void GlslWriter::Visit(ShaderAst::ScopedStatement& node)
{
EnterScope();
node.statement->Visit(*this);
LeaveScope(true);
}
void GlslWriter::Visit(ShaderAst::WhileStatement& node)
{
Append("while (");
node.condition->Visit(*this);
AppendLine(")");
EnterScope();
node.body->Visit(*this);
LeaveScope();
ScopeVisit(*node.body);
}
bool GlslWriter::HasExplicitBinding(ShaderAst::StatementPtr& shader)

View File

@@ -29,63 +29,63 @@ namespace Nz
struct LangWriter::BindingAttribute
{
const ShaderAst::AttributeValue<UInt32>& bindingIndex;
const ShaderAst::ExpressionValue<UInt32>& bindingIndex;
inline bool HasValue() const { return bindingIndex.HasValue(); }
};
struct LangWriter::BuiltinAttribute
{
const ShaderAst::AttributeValue<ShaderAst::BuiltinEntry>& builtin;
const ShaderAst::ExpressionValue<ShaderAst::BuiltinEntry>& builtin;
inline bool HasValue() const { return builtin.HasValue(); }
};
struct LangWriter::DepthWriteAttribute
{
const ShaderAst::AttributeValue<ShaderAst::DepthWriteMode>& writeMode;
const ShaderAst::ExpressionValue<ShaderAst::DepthWriteMode>& writeMode;
inline bool HasValue() const { return writeMode.HasValue(); }
};
struct LangWriter::EarlyFragmentTestsAttribute
{
const ShaderAst::AttributeValue<bool>& earlyFragmentTests;
const ShaderAst::ExpressionValue<bool>& earlyFragmentTests;
inline bool HasValue() const { return earlyFragmentTests.HasValue(); }
};
struct LangWriter::EntryAttribute
{
const ShaderAst::AttributeValue<ShaderStageType>& stageType;
const ShaderAst::ExpressionValue<ShaderStageType>& stageType;
inline bool HasValue() const { return stageType.HasValue(); }
};
struct LangWriter::LayoutAttribute
{
const ShaderAst::AttributeValue<StructLayout>& layout;
const ShaderAst::ExpressionValue<StructLayout>& layout;
inline bool HasValue() const { return layout.HasValue(); }
};
struct LangWriter::LocationAttribute
{
const ShaderAst::AttributeValue<UInt32>& locationIndex;
const ShaderAst::ExpressionValue<UInt32>& locationIndex;
inline bool HasValue() const { return locationIndex.HasValue(); }
};
struct LangWriter::SetAttribute
{
const ShaderAst::AttributeValue<UInt32>& setIndex;
const ShaderAst::ExpressionValue<UInt32>& setIndex;
inline bool HasValue() const { return setIndex.HasValue(); }
};
struct LangWriter::UnrollAttribute
{
const ShaderAst::AttributeValue<ShaderAst::LoopUnroll>& unroll;
const ShaderAst::ExpressionValue<ShaderAst::LoopUnroll>& unroll;
inline bool HasValue() const { return unroll.HasValue(); }
};
@@ -126,14 +126,7 @@ namespace Nz
void LangWriter::Append(const ShaderAst::ArrayType& type)
{
Append("array[", type.containedType->type, ", ");
if (type.length.IsResultingValue())
Append(type.length.GetResultingValue());
else
type.length.GetExpression()->Visit(*this);
Append("]");
Append("array[", type.containedType->type, ", ", type.length, "]");
}
void LangWriter::Append(const ShaderAst::ExpressionType& type)
@@ -144,11 +137,26 @@ namespace Nz
}, type);
}
void LangWriter::Append(const ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type)
{
Append(type.GetResultingValue());
}
void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
{
throw std::runtime_error("unexpected function type");
}
void LangWriter::Append(const ShaderAst::IdentifierType& /*identifierType*/)
{
throw std::runtime_error("unexpected identifier type");
}
void LangWriter::Append(const ShaderAst::IntrinsicFunctionType& /*functionType*/)
{
throw std::runtime_error("unexpected intrinsic function type");
}
void LangWriter::Append(const ShaderAst::MatrixType& matrixType)
{
if (matrixType.columnCount == matrixType.rowCount)
@@ -167,6 +175,11 @@ namespace Nz
Append("[", matrixType.type, "]");
}
void LangWriter::Append(const ShaderAst::MethodType& /*functionType*/)
{
throw std::runtime_error("unexpected method type");
}
void LangWriter::Append(ShaderAst::PrimitiveType type)
{
switch (type)
@@ -201,14 +214,14 @@ namespace Nz
Append(structDesc->name);
}
void LangWriter::Append(const ShaderAst::Type& /*type*/)
{
throw std::runtime_error("unexpected type?");
}
void LangWriter::Append(const ShaderAst::UniformType& uniformType)
{
Append("uniform[");
std::visit([&](auto&& arg)
{
Append(arg);
}, uniformType.containedType);
Append("]");
Append("uniform[", uniformType.containedType, "]");
}
void LangWriter::Append(const ShaderAst::VectorType& vecType)
@@ -411,6 +424,10 @@ namespace Nz
{
switch (entry.layout.GetResultingValue())
{
case StructLayout::Packed:
Append("packed");
break;
case StructLayout::Std140:
Append("std140");
break;
@@ -558,6 +575,18 @@ namespace Nz
m_currentState->variableNames.emplace(varIndex, std::move(varName));
}
void LangWriter::ScopeVisit(ShaderAst::Statement& node)
{
if (node.GetType() != ShaderAst::NodeType::ScopedStatement)
{
EnterScope();
node.Visit(*this);
LeaveScope(true);
}
else
node.Visit(*this);
}
void LangWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired)
{
bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue);
@@ -590,12 +619,19 @@ namespace Nz
assert(!IsStructType(exprType));
// Array access
Append("[");
bool first = true;
for (ShaderAst::ExpressionPtr& expr : node.indices)
{
Append("[");
if (!first)
Append(", ");
expr->Visit(*this);
Append("]");
first = false;
}
Append("]");
}
void LangWriter::Visit(ShaderAst::AssignExpression& node)
@@ -628,9 +664,7 @@ namespace Nz
statement.condition->Visit(*this);
AppendLine(")");
EnterScope();
statement.statement->Visit(*this);
LeaveScope();
ScopeVisit(*statement.statement);
first = false;
}
@@ -639,9 +673,7 @@ namespace Nz
{
AppendLine("else");
EnterScope();
node.elseStatement->Visit(*this);
LeaveScope();
ScopeVisit(*node.elseStatement);
}
}
@@ -800,8 +832,12 @@ namespace Nz
RegisterVariable(varIndex++, node.parameters[i].name);
}
Append(")");
if (!IsNoType(node.returnType))
Append(" -> ", node.returnType);
if (node.returnType.HasValue())
{
const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue();
if (!IsNoType(returnType))
Append(" -> ", returnType);
}
AppendLine();
EnterScope();
@@ -896,9 +932,7 @@ namespace Nz
AppendLine();
EnterScope();
node.statement->Visit(*this);
LeaveScope();
ScopeVisit(*node.statement);
}
void LangWriter::Visit(ShaderAst::ForEachStatement& node)
@@ -911,9 +945,7 @@ namespace Nz
node.expression->Visit(*this);
AppendLine();
EnterScope();
node.statement->Visit(*this);
LeaveScope();
ScopeVisit(*node.statement);
}
void LangWriter::Visit(ShaderAst::IntrinsicExpression& node)
@@ -1001,6 +1033,13 @@ namespace Nz
Append("return;");
}
void LangWriter::Visit(ShaderAst::ScopedStatement& node)
{
EnterScope();
node.statement->Visit(*this);
LeaveScope(true);
}
void LangWriter::Visit(ShaderAst::SwizzleExpression& node)
{
Visit(node.expression, true);
@@ -1043,9 +1082,7 @@ namespace Nz
node.condition->Visit(*this);
AppendLine(")");
EnterScope();
node.body->Visit(*this);
LeaveScope();
ScopeVisit(*node.body);
}
void LangWriter::AppendHeader()

View File

@@ -20,13 +20,6 @@ namespace Nz::ShaderLang
{ "unchanged", ShaderAst::DepthWriteMode::Unchanged },
};
std::unordered_map<std::string, ShaderAst::PrimitiveType> s_identifierToBasicType = {
{ "bool", ShaderAst::PrimitiveType::Boolean },
{ "i32", ShaderAst::PrimitiveType::Int32 },
{ "f32", ShaderAst::PrimitiveType::Float32 },
{ "u32", ShaderAst::PrimitiveType::UInt32 }
};
std::unordered_map<std::string, ShaderAst::AttributeType> s_identifierToAttributeType = {
{ "binding", ShaderAst::AttributeType::Binding },
{ "builtin", ShaderAst::AttributeType::Builtin },
@@ -71,7 +64,7 @@ namespace Nz::ShaderLang
}
template<typename T>
void HandleUniqueAttribute(const std::string_view& attributeName, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param, bool requireValue = true)
void HandleUniqueAttribute(const std::string_view& attributeName, ShaderAst::ExpressionValue<T>& targetAttribute, ShaderAst::ExprValue::Param&& param, bool requireValue = true)
{
if (targetAttribute.HasValue())
throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" };
@@ -83,7 +76,7 @@ namespace Nz::ShaderLang
}
template<typename T>
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::AttributeValue<T>& targetAttribute, ShaderAst::Attribute::Param&& param, std::optional<T> defaultValue = {})
void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map<std::string, T>& map, ShaderAst::ExpressionValue<T>& targetAttribute, ShaderAst::ExprValue::Param&& param, std::optional<T> defaultValue = {})
{
if (targetAttribute.HasValue())
throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" };
@@ -123,9 +116,7 @@ namespace Nz::ShaderLang
m_context = &context;
std::vector<ShaderAst::Attribute> attributes;
EnterScope();
std::vector<ShaderAst::ExprValue> attributes;
bool reachedEndOfStream = false;
while (!reachedEndOfStream)
@@ -179,8 +170,6 @@ namespace Nz::ShaderLang
}
}
LeaveScope();
return std::move(context.root);
}
@@ -198,161 +187,6 @@ namespace Nz::ShaderLang
m_context->tokenIndex += count;
}
std::optional<ShaderAst::ExpressionType> Parser::DecodeType(const std::string& identifier)
{
if (auto it = s_identifierToBasicType.find(identifier); it != s_identifierToBasicType.end())
{
Consume();
return it->second;
}
//FIXME: Handle this better
if (identifier == "array")
{
Consume();
Expect(Advance(), TokenType::OpenSquareBracket); //< [
ShaderAst::ArrayType arrayType;
arrayType.containedType = std::make_unique<ShaderAst::ContainedType>();
arrayType.containedType->type = ParseType();
Expect(Advance(), TokenType::Comma); //< ,
arrayType.length = ParseExpression();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return arrayType;
}
else if (identifier == "mat4")
{
Consume();
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 4;
matrixType.rowCount = 4;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return matrixType;
}
else if (identifier == "mat3")
{
Consume();
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 3;
matrixType.rowCount = 3;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return matrixType;
}
else if (identifier == "mat2")
{
Consume();
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 2;
matrixType.rowCount = 2;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return matrixType;
}
else if (identifier == "sampler2D")
{
Consume();
ShaderAst::SamplerType samplerType;
samplerType.dim = ImageType::E2D;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return samplerType;
}
else if (identifier == "samplerCube")
{
Consume();
ShaderAst::SamplerType samplerType;
samplerType.dim = ImageType::Cubemap;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return samplerType;
}
else if (identifier == "uniform")
{
Consume();
ShaderAst::UniformType uniformType;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() };
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return uniformType;
}
else if (identifier == "vec2")
{
Consume();
ShaderAst::VectorType vectorType;
vectorType.componentCount = 2;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return vectorType;
}
else if (identifier == "vec3")
{
Consume();
ShaderAst::VectorType vectorType;
vectorType.componentCount = 3;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return vectorType;
}
else if (identifier == "vec4")
{
Consume();
ShaderAst::VectorType vectorType;
vectorType.componentCount = 4;
Expect(Advance(), TokenType::OpenSquareBracket); //< [
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::ClosingSquareBracket); //< ]
return vectorType;
}
else
return std::nullopt;
}
void Parser::EnterScope()
{
m_context->scopeSizes.push_back(m_context->identifiersInScope.size());
}
const Token& Parser::Expect(const Token& token, TokenType type)
{
if (token.type != type)
@@ -377,33 +211,15 @@ namespace Nz::ShaderLang
return token;
}
void Parser::LeaveScope()
{
assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.resize(m_context->scopeSizes.back());
m_context->scopeSizes.pop_back();
}
bool Parser::IsVariableInScope(const std::string_view& identifier) const
{
return std::find(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), identifier) != m_context->identifiersInScope.rend();
}
void Parser::RegisterVariable(std::string identifier)
{
assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.push_back(std::move(identifier));
}
const Token& Parser::Peek(std::size_t advance)
{
assert(m_context->tokenIndex + advance < m_context->tokenCount);
return m_context->tokens[m_context->tokenIndex + advance];
}
std::vector<ShaderAst::Attribute> Parser::ParseAttributes()
std::vector<ShaderAst::ExprValue> Parser::ParseAttributes()
{
std::vector<ShaderAst::Attribute> attributes;
std::vector<ShaderAst::ExprValue> attributes;
Expect(Advance(), TokenType::OpenSquareBracket);
@@ -431,7 +247,7 @@ namespace Nz::ShaderLang
ShaderAst::AttributeType attributeType = ParseIdentifierAsAttributeType();
ShaderAst::Attribute::Param arg;
ShaderAst::ExprValue::Param arg;
if (Peek().type == TokenType::OpenParenthesis)
{
Consume();
@@ -454,7 +270,7 @@ namespace Nz::ShaderLang
return attributes;
}
void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionType& type, ShaderAst::ExpressionPtr& initialValue)
void Parser::ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type, ShaderAst::ExpressionPtr& initialValue)
{
name = ParseIdentifierAsName();
@@ -464,10 +280,8 @@ namespace Nz::ShaderLang
type = ParseType();
}
else
type = ShaderAst::NoType{};
if (IsNoType(type) || Peek().type == TokenType::Assign)
if (!type.HasValue() || Peek().type == TokenType::Assign)
{
Expect(Advance(), TokenType::Assign);
initialValue = ParseExpression();
@@ -522,11 +336,10 @@ namespace Nz::ShaderLang
case TokenType::Identifier:
{
std::string constName;
ShaderAst::ExpressionType constType;
ShaderAst::ExpressionValue<ShaderAst::ExpressionType> constType;
ShaderAst::ExpressionPtr initialValue;
ParseVariableDeclaration(constName, constType, initialValue);
RegisterVariable(constName);
return ShaderBuilder::DeclareConst(std::move(constName), std::move(constType), std::move(initialValue));
}
@@ -552,14 +365,14 @@ namespace Nz::ShaderLang
return ShaderBuilder::Discard();
}
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::External);
Expect(Advance(), TokenType::OpenCurlyBracket);
std::unique_ptr<ShaderAst::DeclareExternalStatement> externalStatement = std::make_unique<ShaderAst::DeclareExternalStatement>();
ShaderAst::AttributeValue<bool> condition;
ShaderAst::ExpressionValue<bool> condition;
for (auto&& [attributeType, arg] : attributes)
{
@@ -624,8 +437,6 @@ namespace Nz::ShaderLang
extVar.name = ParseIdentifierAsName();
Expect(Advance(), TokenType::Colon);
extVar.type = ParseType();
RegisterVariable(extVar.name);
}
Expect(Advance(), TokenType::ClosingCurlyBracket);
@@ -636,7 +447,7 @@ namespace Nz::ShaderLang
return externalStatement;
}
ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseForDeclaration(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::For);
@@ -710,7 +521,7 @@ namespace Nz::ShaderLang
return ParseStatementList();
}
ShaderAst::StatementPtr Parser::ParseFunctionDeclaration(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseFunctionDeclaration(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::FunctionDeclaration);
@@ -738,24 +549,18 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingParenthesis);
ShaderAst::ExpressionType returnType;
ShaderAst::ExpressionValue<ShaderAst::ExpressionType> returnType;
if (Peek().type == TokenType::Arrow)
{
Consume();
returnType = ParseType();
}
EnterScope();
for (const auto& parameter : parameters)
RegisterVariable(parameter.name);
std::vector<ShaderAst::StatementPtr> functionBody = ParseFunctionBody();
LeaveScope();
auto func = ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
ShaderAst::AttributeValue<bool> condition;
ShaderAst::ExpressionValue<bool> condition;
for (auto&& [attributeType, arg] : attributes)
{
@@ -794,7 +599,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Colon);
ShaderAst::ExpressionType parameterType = ParseType();
ShaderAst::ExpressionPtr parameterType = ParseType();
return { parameterName, std::move(parameterType) };
}
@@ -807,7 +612,7 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Colon);
ShaderAst::ExpressionType optionType = ParseType();
ShaderAst::ExpressionPtr optionType = ParseType();
ShaderAst::ExpressionPtr initialValue;
if (Peek().type == TokenType::Assign)
@@ -837,7 +642,7 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr Parser::ParseSingleStatement()
{
std::vector<ShaderAst::Attribute> attributes;
std::vector<ShaderAst::ExprValue> attributes;
ShaderAst::StatementPtr statement;
do
{
@@ -912,15 +717,13 @@ namespace Nz::ShaderLang
ShaderAst::StatementPtr Parser::ParseStatement()
{
if (Peek().type == TokenType::OpenCurlyBracket)
return ShaderBuilder::MultiStatement(ParseStatementList());
return ShaderBuilder::Scoped(ShaderBuilder::MultiStatement(ParseStatementList()));
else
return ParseSingleStatement();
}
std::vector<ShaderAst::StatementPtr> Parser::ParseStatementList()
{
EnterScope();
Expect(Advance(), TokenType::OpenCurlyBracket);
std::vector<ShaderAst::StatementPtr> statements;
@@ -931,19 +734,17 @@ namespace Nz::ShaderLang
}
Consume(); //< Consume closing curly bracket
LeaveScope();
return statements;
}
ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::Struct);
ShaderAst::StructDescription description;
description.name = ParseIdentifierAsName();
ShaderAst::AttributeValue<bool> condition;
ShaderAst::ExpressionValue<bool> condition;
for (auto&& [attributeType, attributeParam] : attributes)
{
@@ -1065,16 +866,15 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Let);
std::string variableName;
ShaderAst::ExpressionType variableType;
ShaderAst::ExpressionValue<ShaderAst::ExpressionType> variableType;
ShaderAst::ExpressionPtr expression;
ParseVariableDeclaration(variableName, variableType, expression);
RegisterVariable(variableName);
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}
ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector<ShaderAst::Attribute> attributes)
ShaderAst::StatementPtr Parser::ParseWhileStatement(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::While);
@@ -1133,19 +933,6 @@ namespace Nz::ShaderLang
accessMemberNode->identifiers.push_back(ParseIdentifierAsName());
} while (Peek().type == TokenType::Dot);
// FIXME
if (!accessMemberNode->identifiers.empty() && accessMemberNode->identifiers.front() == "Sample")
{
if (Peek().type == TokenType::OpenParenthesis)
{
auto parameters = ParseParameters();
parameters.insert(parameters.begin(), std::move(accessMemberNode->expr));
lhs = ShaderBuilder::Intrinsic(ShaderAst::IntrinsicType::SampleTexture, std::move(parameters));
break;
}
}
lhs = std::move(accessMemberNode);
}
else
@@ -1160,10 +947,10 @@ namespace Nz::ShaderLang
Consume();
indexNode->indices.push_back(ParseExpression());
Expect(Advance(), TokenType::ClosingSquareBracket);
}
while (Peek().type == TokenType::OpenSquareBracket);
while (Peek().type == TokenType::Comma);
Expect(Advance(), TokenType::ClosingSquareBracket);
lhs = std::move(indexNode);
}
@@ -1171,6 +958,15 @@ namespace Nz::ShaderLang
currentTokenType = Peek().type;
}
if (currentTokenType == TokenType::OpenParenthesis)
{
// Function call
auto parameters = ParseParameters();
lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters));
c = true;
}
if (c)
continue;
@@ -1302,23 +1098,7 @@ namespace Nz::ShaderLang
return ParseFloatingPointExpression();
case TokenType::Identifier:
{
const std::string& identifier = std::get<std::string>(token.data);
// Is it a cast?
std::optional<ShaderAst::ExpressionType> exprType = DecodeType(identifier);
if (exprType)
return ShaderBuilder::Cast(std::move(*exprType), ParseParameters());
if (Peek(1).type == TokenType::OpenParenthesis)
{
// Function call
Consume();
return ShaderBuilder::CallFunction(identifier, ParseParameters());
}
else
return ParseIdentifier();
}
return ParseIdentifier();
case TokenType::IntegerValue:
return ParseIntegerExpression();
@@ -1370,28 +1150,10 @@ namespace Nz::ShaderLang
const std::string& Parser::ParseIdentifierAsName()
{
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
auto it = s_identifierToBasicType.find(identifier);
if (it != s_identifierToBasicType.end())
throw ReservedKeyword{};
return identifier;
return std::get<std::string>(identifierToken.data);
}
ShaderAst::PrimitiveType Parser::ParsePrimitiveType()
{
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
auto it = s_identifierToBasicType.find(identifier);
if (it == s_identifierToBasicType.end())
throw UnknownType{};
return it->second;
}
ShaderAst::ExpressionType Parser::ParseType()
ShaderAst::ExpressionPtr Parser::ParseType()
{
// Handle () as no type
if (Peek().type == TokenType::OpenParenthesis)
@@ -1399,20 +1161,10 @@ namespace Nz::ShaderLang
Consume();
Expect(Advance(), TokenType::ClosingParenthesis);
return ShaderAst::NoType{};
return ShaderBuilder::Constant(ShaderAst::NoValue{});
}
const Token& identifierToken = Expect(Peek(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
auto type = DecodeType(identifier);
if (!type)
{
Consume();
return ShaderAst::IdentifierType{ identifier };
}
return *std::move(type);
return ParseExpression();
}
int Parser::GetTokenPrecedence(TokenType token)
@@ -1433,6 +1185,7 @@ namespace Nz::ShaderLang
case TokenType::NotEqual: return 50;
case TokenType::Plus: return 60;
case TokenType::OpenSquareBracket: return 100;
case TokenType::OpenParenthesis: return 100;
default: return -1;
}
}

View File

@@ -400,8 +400,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node)
{
assert(std::holds_alternative<std::size_t>(node.targetFunction));
std::size_t functionIndex = std::get<std::size_t>(node.targetFunction);
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex;
UInt32 funcId = 0;
for (const auto& [funcIndex, func] : m_funcData)
@@ -443,7 +442,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node)
{
const ShaderAst::ExpressionType& targetExprType = node.targetType;
const ShaderAst::ExpressionType& targetExprType = node.targetType.GetResultingValue();
if (IsPrimitiveType(targetExprType))
{
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
@@ -584,7 +583,7 @@ namespace Nz
std::size_t varIndex = *node.varIndex;
for (auto&& extVar : node.externalVars)
RegisterExternalVariable(varIndex++, extVar.type);
RegisterExternalVariable(varIndex++, extVar.type.GetResultingValue());
}
void SpirvAstVisitor::Visit(ShaderAst::DeclareFunctionStatement& node)
@@ -674,7 +673,7 @@ namespace Nz
{
const auto& func = m_funcData[m_funcIndex];
UInt32 typeId = m_writer.GetTypeId(node.varType);
UInt32 typeId = m_writer.GetTypeId(node.varType.GetResultingValue());
assert(node.varIndex);
auto varIt = func.varIndexToVarId.find(*node.varIndex);
@@ -932,6 +931,11 @@ namespace Nz
m_currentBlock->Append(SpirvOp::OpReturn);
}
void SpirvAstVisitor::Visit(ShaderAst::ScopedStatement& node)
{
node.statement->Visit(*this);
}
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{
const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression);

View File

@@ -513,7 +513,7 @@ namespace Nz
for (const Structure::Member& member : structData.members)
{
member.offset = std::visit([&](auto&& arg) -> std::size_t
member.offset = SafeCast<UInt32>(std::visit([&](auto&& arg) -> std::size_t
{
using T = std::decay_t<decltype(arg)>;
@@ -601,7 +601,7 @@ namespace Nz
throw std::runtime_error("unexpected void as struct member");
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, member.type->type);
}, member.type->type));
}
return structOffsets;
@@ -671,7 +671,7 @@ namespace Nz
return std::make_shared<Type>(Array{
builtContainedType,
BuildConstant(type.length.GetResultingValue()),
BuildConstant(type.length),
arrayStride
});
}
@@ -802,7 +802,7 @@ namespace Nz
auto& sMembers = sType.members.emplace_back();
sMembers.name = member.name;
sMembers.type = BuildType(member.type);
sMembers.type = BuildType(member.type.GetResultingValue());
}
m_internal->isInBlockStruct = wasInBlock;
@@ -817,8 +817,7 @@ namespace Nz
auto SpirvConstantCache::BuildType(const ShaderAst::UniformType& type) const -> TypePtr
{
assert(std::holds_alternative<ShaderAst::StructType>(type.containedType));
return BuildType(std::get<ShaderAst::StructType>(type.containedType));
return BuildType(type.containedType);
}
UInt32 SpirvConstantCache::GetId(const Constant& c)
@@ -918,12 +917,12 @@ namespace Nz
return fieldOffsets.AddFieldArray(TypeToStructFieldType(type), arrayLength);
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Function& type, std::size_t arrayLength) const
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Function& /*type*/, std::size_t /*arrayLength*/) const
{
throw std::runtime_error("unexpected Function");
}
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Image& type, std::size_t arrayLength) const
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Image& /*type*/, std::size_t /*arrayLength*/) const
{
throw std::runtime_error("unexpected Image");
}

View File

@@ -30,9 +30,46 @@ namespace Nz
return resultId;
},
[this](const PointerChainAccess& pointerChainAccess) -> UInt32
{
UInt32 pointerType = m_writer.RegisterPointerType(*pointerChainAccess.exprType, pointerChainAccess.storage); //< FIXME: We shouldn't register this so late
UInt32 pointerId = m_visitor.AllocateResultId();
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(pointerId);
appender(pointerChainAccess.pointerId);
for (UInt32 id : pointerChainAccess.indices)
appender(id);
});
UInt32 resultId = m_visitor.AllocateResultId();
m_block.Append(SpirvOp::OpLoad, m_writer.GetTypeId(*pointerChainAccess.exprType), resultId, pointerId);
return resultId;
},
[](const Value& value) -> UInt32
{
return value.resultId;
return value.valueId;
},
[this](const ValueExtraction& extractedValue) -> UInt32
{
UInt32 resultId = m_visitor.AllocateResultId();
m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender)
{
appender(extractedValue.typeId);
appender(resultId);
appender(extractedValue.valueId);
for (UInt32 id : extractedValue.indices)
appender(id);
});
return resultId;
},
[](std::monostate) -> UInt32
{
@@ -47,48 +84,42 @@ namespace Nz
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
UInt32 resultId = m_visitor.AllocateResultId();
UInt32 typeId = m_writer.GetTypeId(exprType);
assert(node.indices.size() == 1);
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
std::visit(overloaded
{
[&](const Pointer& pointer)
{
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
PointerChainAccess pointerChainAccess;
pointerChainAccess.exprType = &exprType;
pointerChainAccess.indices = { indexId };
pointerChainAccess.pointedTypeId = pointer.pointedTypeId;
pointerChainAccess.pointerId = pointer.pointerId;
pointerChainAccess.storage = pointer.storage;
StackArray<UInt32> indexIds = NazaraStackArrayNoInit(UInt32, node.indices.size());
for (std::size_t i = 0; i < node.indices.size(); ++i)
indexIds[i] = m_visitor.EvaluateExpression(node.indices[i]);
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(resultId);
appender(pointer.pointerId);
for (UInt32 id : indexIds)
appender(id);
});
m_value = Pointer { pointer.storage, resultId, typeId };
m_value = std::move(pointerChainAccess);
},
[&](PointerChainAccess& pointerChainAccess)
{
pointerChainAccess.exprType = &exprType;
pointerChainAccess.indices.push_back(indexId);
},
[&](const Value& value)
{
StackArray<UInt32> indexIds = NazaraStackArrayNoInit(UInt32, node.indices.size());
for (std::size_t i = 0; i < node.indices.size(); ++i)
indexIds[i] = m_visitor.EvaluateExpression(node.indices[i]);
ValueExtraction extractedValue;
extractedValue.indices = { indexId };
extractedValue.typeId = typeId;
extractedValue.valueId = value.valueId;
m_block.AppendVariadic(SpirvOp::OpCompositeExtract, [&](const auto& appender)
{
appender(typeId);
appender(resultId);
appender(value.resultId);
for (UInt32 id : indexIds)
appender(id);
});
m_value = Value { resultId };
m_value = std::move(extractedValue);
},
[&](ValueExtraction& extractedValue)
{
extractedValue.indices.push_back(indexId);
extractedValue.typeId = typeId;
},
[](std::monostate)
{

View File

@@ -100,19 +100,10 @@ namespace Nz
UInt32 resultId = m_visitor.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
StackArray<UInt32> indexIds = NazaraStackArrayNoInit(UInt32, node.indices.size());
for (std::size_t i = 0; i < node.indices.size(); ++i)
indexIds[i] = m_visitor.EvaluateExpression(node.indices[i]);
assert(node.indices.size() == 1);
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
m_block.AppendVariadic(SpirvOp::OpAccessChain, [&](const auto& appender)
{
appender(pointerType);
appender(resultId);
appender(pointer.pointerId);
for (UInt32 id : indexIds)
appender(id);
});
m_block.Append(SpirvOp::OpAccessChain, pointerType, resultId, pointer.pointerId, indexId);
m_value = Pointer { pointer.storage, resultId };
},
@@ -147,6 +138,8 @@ namespace Nz
{
// Swizzle the swizzle, keep common components
std::array<UInt32, 4> newIndices;
newIndices.fill(0); //< keep compiler happy
for (std::size_t i = 0; i < node.componentCount; ++i)
{
assert(node.components[i] < swizzledPointer.componentCount);

View File

@@ -144,17 +144,18 @@ namespace Nz
SpirvConstantCache::Variable variable;
variable.debugName = extVar.name;
if (ShaderAst::IsSamplerType(extVar.type))
const ShaderAst::ExpressionType& extVarType = extVar.type.GetResultingValue();
if (ShaderAst::IsSamplerType(extVarType))
{
variable.storageClass = SpirvStorageClass::UniformConstant;
variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass);
variable.type = m_constantCache.BuildPointerType(extVarType, variable.storageClass);
}
else
{
assert(ShaderAst::IsUniformType(extVar.type));
const auto& uniformType = std::get<ShaderAst::UniformType>(extVar.type);
assert(std::holds_alternative<ShaderAst::StructType>(uniformType.containedType));
const auto& structType = std::get<ShaderAst::StructType>(uniformType.containedType);
assert(ShaderAst::IsUniformType(extVarType));
const auto& uniformType = std::get<ShaderAst::UniformType>(extVarType);
const auto& structType = uniformType.containedType;
assert(structType.structIndex < declaredStructs.size());
const auto& type = m_constantCache.BuildType(*declaredStructs[structType.structIndex], { SpirvDecoration::Block });
@@ -188,16 +189,27 @@ namespace Nz
{
std::vector<ShaderAst::ExpressionType> parameterTypes;
for (auto& parameter : node.parameters)
parameterTypes.push_back(parameter.type);
parameterTypes.push_back(parameter.type.GetResultingValue());
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(node.returnType, parameterTypes));
if (node.returnType.HasValue())
{
const auto& returnType = node.returnType.GetResultingValue();
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(returnType));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(returnType, parameterTypes));
}
else
{
funcData.returnTypeId = m_constantCache.Register(*m_constantCache.BuildType(ShaderAst::NoType{}));
funcData.funcTypeId = m_constantCache.Register(*m_constantCache.BuildFunctionType(ShaderAst::NoType{}, parameterTypes));
}
for (auto& parameter : node.parameters)
{
const auto& parameterType = parameter.type.GetResultingValue();
auto& funcParam = funcData.parameters.emplace_back();
funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function));
funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameter.type));
funcParam.pointerTypeId = m_constantCache.Register(*m_constantCache.BuildPointerType(parameterType, SpirvStorageClass::Function));
funcParam.typeId = m_constantCache.Register(*m_constantCache.BuildType(parameterType));
}
}
else
@@ -235,9 +247,11 @@ namespace Nz
{
assert(node.parameters.size() == 1);
auto& parameter = node.parameters.front();
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
const auto& parameterType = parameter.type.GetResultingValue();
std::size_t structIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
assert(std::holds_alternative<ShaderAst::StructType>(parameterType));
std::size_t structIndex = std::get<ShaderAst::StructType>(parameterType).structIndex;
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
@@ -250,7 +264,7 @@ namespace Nz
{
inputs.push_back({
m_constantCache.Register(*m_constantCache.BuildConstant(Int32(memberIndex))),
m_constantCache.Register(*m_constantCache.BuildPointerType(member.type, SpirvStorageClass::Function)),
m_constantCache.Register(*m_constantCache.BuildPointerType(member.type.GetResultingValue(), SpirvStorageClass::Function)),
varId
});
}
@@ -259,18 +273,20 @@ namespace Nz
}
inputStruct = EntryPoint::InputStruct{
m_constantCache.Register(*m_constantCache.BuildPointerType(parameter.type, SpirvStorageClass::Function)),
m_constantCache.Register(*m_constantCache.BuildType(parameter.type))
m_constantCache.Register(*m_constantCache.BuildPointerType(parameterType, SpirvStorageClass::Function)),
m_constantCache.Register(*m_constantCache.BuildType(parameter.type.GetResultingValue()))
};
}
std::optional<UInt32> outputStructId;
std::vector<EntryPoint::Output> outputs;
if (!IsNoType(node.returnType))
if (node.returnType.HasValue())
{
assert(std::holds_alternative<ShaderAst::StructType>(node.returnType));
const ShaderAst::ExpressionType& returnType = node.returnType.GetResultingValue();
std::size_t structIndex = std::get<ShaderAst::StructType>(node.returnType).structIndex;
assert(std::holds_alternative<ShaderAst::StructType>(returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex];
std::size_t memberIndex = 0;
@@ -283,7 +299,7 @@ namespace Nz
{
outputs.push_back({
Int32(memberIndex),
m_constantCache.Register(*m_constantCache.BuildType(member.type)),
m_constantCache.Register(*m_constantCache.BuildType(member.type.GetResultingValue())),
varId
});
}
@@ -291,7 +307,7 @@ namespace Nz
memberIndex++;
}
outputStructId = m_constantCache.Register(*m_constantCache.BuildType(node.returnType));
outputStructId = m_constantCache.Register(*m_constantCache.BuildType(returnType));
}
funcData.entryPointData = EntryPoint{
@@ -334,7 +350,7 @@ namespace Nz
func.varIndexToVarId[*node.varIndex] = func.variables.size();
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType, SpirvStorageClass::Function));
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(node.varType.GetResultingValue(), SpirvStorageClass::Function));
}
void Visit(ShaderAst::IdentifierExpression& node) override
@@ -408,7 +424,7 @@ namespace Nz
variable.debugName = builtin.debugName;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
variable.type = m_constantCache.BuildPointerType(member.type.GetResultingValue(), storageClass);
UInt32 varId = m_constantCache.Register(variable);
builtinDecorations[varId] = builtinDecoration;
@@ -421,7 +437,7 @@ namespace Nz
variable.debugName = member.name;
variable.funcId = funcIndex;
variable.storageClass = storageClass;
variable.type = m_constantCache.BuildPointerType(member.type, storageClass);
variable.type = m_constantCache.BuildPointerType(member.type.GetResultingValue(), storageClass);
UInt32 varId = m_constantCache.Register(variable);
locationDecorations[varId] = member.locationIndex.GetResultingValue();
@@ -643,9 +659,12 @@ namespace Nz
parameterTypes.reserve(functionNode.parameters.size());
for (const auto& parameter : functionNode.parameters)
parameterTypes.push_back(parameter.type);
parameterTypes.push_back(parameter.type.GetResultingValue());
return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType, parameterTypes);
if (functionNode.returnType.HasValue())
return m_currentState->constantTypeCache.BuildFunctionType(functionNode.returnType.GetResultingValue(), parameterTypes);
else
return m_currentState->constantTypeCache.BuildFunctionType(ShaderAst::NoType{}, parameterTypes);
}
UInt32 SpirvWriter::GetConstantId(const ShaderAst::ConstantValue& value) const