ShaderLang: Proof of concept (add support for a lot of things)

This commit is contained in:
Jérôme Leclercq
2021-03-31 10:21:35 +02:00
parent 2a73005295
commit c1d1838336
37 changed files with 2259 additions and 908 deletions

View File

@@ -20,37 +20,43 @@ namespace Nz
namespace
{
static const char* flipYUniformName = "_NzFlipValue";
static const char* overridenMain = "_NzMain";
struct AstAdapter : ShaderAst::AstCloner
{
void Visit(ShaderAst::AssignExpression& node) override
using AstCloner::Clone;
std::unique_ptr<ShaderAst::DeclareFunctionStatement> Clone(ShaderAst::DeclareFunctionStatement& node) override
{
if (!flipYPosition)
return AstCloner::Visit(node);
auto clone = AstCloner::Clone(node);
if (clone->name == "main")
clone->name = "_NzMain";
if (node.left->GetType() != ShaderAst::NodeType::IdentifierExpression)
return AstCloner::Visit(node);
/*
FIXME:
const auto& identifier = static_cast<const ShaderAst::Identifier&>(*node.left);
if (identifier.var->GetType() != ShaderAst::VariableType::BuiltinVariable)
return ShaderAstCloner::Visit(node);
const auto& builtinVar = static_cast<const ShaderAst::BuiltinVariable&>(*identifier.var);
if (builtinVar.entry != ShaderAst::BuiltinEntry::VertexPosition)
return ShaderAstCloner::Visit(node);
auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderAst::BasicType::Float1);
auto oneConstant = ShaderBuilder::Constant(1.f);
auto fixYValue = ShaderBuilder::Cast<ShaderAst::BasicType::Float4>(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant);
auto mulFix = ShaderBuilder::Multiply(CloneExpression(node.right), fixYValue);
PushExpression(ShaderAst::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));*/
return clone;
}
bool flipYPosition = false;
void Visit(ShaderAst::DeclareFunctionStatement& node)
{
if (removedEntryPoints.find(&node) != removedEntryPoints.end())
{
PushStatement(ShaderBuilder::NoOp());
return;
}
AstCloner::Visit(node);
}
std::unordered_set<ShaderAst::DeclareFunctionStatement*> removedEntryPoints;
};
struct Builtin
{
std::string identifier;
ShaderStageTypeFlags stageFlags;
};
std::unordered_map<std::string, Builtin> builtinMapping = {
{ "position", { "gl_Position", ShaderStageType::Vertex } }
};
}
@@ -59,6 +65,7 @@ namespace Nz
{
const States* states = nullptr;
ShaderAst::AstCache cache;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream;
unsigned int indentLevel = 0;
};
@@ -69,19 +76,8 @@ namespace Nz
{
}
std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
std::string GlslWriter::Generate(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader, const States& conditions)
{
/*const ShaderAst* selectedShader = &inputShader;
std::optional<ShaderAst> modifiedShader;
if (inputShader.GetStage() == ShaderStageType::Vertex && m_environment.flipYPosition)
{
modifiedShader.emplace(inputShader);
modifiedShader->AddUniform(flipYUniformName, ShaderAst::BasicType::Float1);
selectedShader = &modifiedShader.value();
}*/
State state;
m_currentState = &state;
CallOnExit onExit([this]()
@@ -93,6 +89,27 @@ namespace Nz
if (!ShaderAst::ValidateAst(shader, &error, &state.cache))
throw std::runtime_error("Invalid shader AST: " + error);
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
if (!state.entryFunc)
throw std::runtime_error("missing entry point");
AstAdapter adapter;
for (ShaderAst::DeclareFunctionStatement* entryFunc : state.cache.entryFunctions)
{
if (entryFunc != state.entryFunc)
adapter.removedEntryPoints.insert(entryFunc);
}
ShaderAst::StatementPtr adaptedShader = adapter.Clone(shader);
state.cache.Clear();
if (!ShaderAst::ValidateAst(adaptedShader, &error, &state.cache))
throw std::runtime_error("Internal error:" + error);
state.entryFunc = state.cache.entryFunctions[UnderlyingCast(shaderStage)];
assert(state.entryFunc);
unsigned int glslVersion;
if (m_environment.glES)
{
@@ -141,14 +158,14 @@ namespace Nz
if (!m_environment.glES && m_environment.extCallback)
{
// GL_ARB_shading_language_420pack (required for layout(binding = X))
if (glslVersion < 420 && HasExplicitBinding(shader))
if (glslVersion < 420 && HasExplicitBinding(adaptedShader))
{
if (m_environment.extCallback("GL_ARB_shading_language_420pack"))
requiredExtensions.emplace_back("GL_ARB_shading_language_420pack");
}
// GL_ARB_separate_shader_objects (required for layout(location = X))
if (glslVersion < 410 && HasExplicitLocation(shader))
if (glslVersion < 410 && HasExplicitLocation(adaptedShader))
{
if (m_environment.extCallback("GL_ARB_separate_shader_objects"))
requiredExtensions.emplace_back("GL_ARB_separate_shader_objects");
@@ -173,7 +190,10 @@ namespace Nz
AppendLine();
}
shader->Visit(*this);
adaptedShader->Visit(*this);
// Append true GLSL entry point
AppendEntryPoint(shaderStage);
return state.stream.str();
}
@@ -188,7 +208,7 @@ namespace Nz
return flipYUniformName;
}
void GlslWriter::Append(ShaderAst::ShaderExpressionType type)
void GlslWriter::Append(const ShaderAst::ExpressionType& type)
{
std::visit([&](auto&& arg)
{
@@ -206,29 +226,82 @@ namespace Nz
}
}
void GlslWriter::Append(ShaderAst::BasicType type)
void GlslWriter::Append(const ShaderAst::IdentifierType& identifierType)
{
Append(identifierType.name);
}
void GlslWriter::Append(const ShaderAst::MatrixType& matrixType)
{
if (matrixType.columnCount == matrixType.rowCount)
{
Append("mat");
Append(matrixType.columnCount);
}
else
{
Append("mat");
Append(matrixType.columnCount);
Append("x");
Append(matrixType.rowCount);
}
}
void GlslWriter::Append(ShaderAst::PrimitiveType type)
{
switch (type)
{
case ShaderAst::BasicType::Boolean: return Append("bool");
case ShaderAst::BasicType::Float1: return Append("float");
case ShaderAst::BasicType::Float2: return Append("vec2");
case ShaderAst::BasicType::Float3: return Append("vec3");
case ShaderAst::BasicType::Float4: return Append("vec4");
case ShaderAst::BasicType::Int1: return Append("int");
case ShaderAst::BasicType::Int2: return Append("ivec2");
case ShaderAst::BasicType::Int3: return Append("ivec3");
case ShaderAst::BasicType::Int4: return Append("ivec4");
case ShaderAst::BasicType::Mat4x4: return Append("mat4");
case ShaderAst::BasicType::Sampler2D: return Append("sampler2D");
case ShaderAst::BasicType::UInt1: return Append("uint");
case ShaderAst::BasicType::UInt2: return Append("uvec2");
case ShaderAst::BasicType::UInt3: return Append("uvec3");
case ShaderAst::BasicType::UInt4: return Append("uvec4");
case ShaderAst::BasicType::Void: return Append("void");
case ShaderAst::PrimitiveType::Boolean: return Append("bool");
case ShaderAst::PrimitiveType::Float32: return Append("float");
case ShaderAst::PrimitiveType::Int32: return Append("ivec2");
case ShaderAst::PrimitiveType::UInt32: return Append("uint");
}
}
void GlslWriter::Append(const ShaderAst::SamplerType& samplerType)
{
switch (samplerType.sampledType)
{
case ShaderAst::PrimitiveType::Boolean:
case ShaderAst::PrimitiveType::Float32:
break;
case ShaderAst::PrimitiveType::Int32: Append("i"); break;
case ShaderAst::PrimitiveType::UInt32: Append("u"); break;
}
Append("sampler");
switch (samplerType.dim)
{
case ImageType_1D: Append("1D"); break;
case ImageType_1D_Array: Append("1DArray"); break;
case ImageType_2D: Append("2D"); break;
case ImageType_2D_Array: Append("2DArray"); break;
case ImageType_3D: Append("3D"); break;
case ImageType_Cubemap: Append("Cube"); break;
}
}
void GlslWriter::Append(const ShaderAst::UniformType& uniformType)
{
/* TODO */
}
void GlslWriter::Append(const ShaderAst::VectorType& vecType)
{
switch (vecType.type)
{
case ShaderAst::PrimitiveType::Boolean: Append("b"); break;
case ShaderAst::PrimitiveType::Float32: break;
case ShaderAst::PrimitiveType::Int32: Append("i"); break;
case ShaderAst::PrimitiveType::UInt32: Append("u"); break;
}
Append("vec");
Append(vecType.componentCount);
}
void GlslWriter::Append(ShaderAst::MemoryLayout layout)
{
switch (layout)
@@ -239,6 +312,11 @@ namespace Nz
}
}
void GlslWriter::Append(ShaderAst::NoType)
{
return Append("void");
}
template<typename T>
void GlslWriter::Append(const T& param)
{
@@ -246,6 +324,12 @@ namespace Nz
m_currentState->stream << param;
}
template<typename T1, typename T2, typename... Args>
void GlslWriter::Append(const T1& firstParam, const T2& secondParam, Args&&... params)
{
Append(firstParam);
Append(secondParam, std::forward<Args>(params)...);
}
void GlslWriter::AppendCommentSection(const std::string& section)
{
@@ -256,6 +340,152 @@ namespace Nz
AppendLine();
}
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage)
{
AppendLine();
AppendLine("// Entry point handling");
struct InOutField
{
std::string name;
std::string targetName;
};
std::vector<InOutField> inputFields;
const ShaderAst::StructDescription* inputStruct = nullptr;
auto HandleInOutStructs = [this, shaderStage](const ShaderAst::ExpressionType& expressionType, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
{
assert(IsIdentifierType(expressionType));
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::IdentifierType>(expressionType).name);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
for (const auto& member : s.members)
{
bool skip = false;
std::optional<std::string> builtinName;
std::optional<long long> attributeLocation;
for (const auto& [attributeType, attributeParam] : member.attributes)
{
if (attributeType == ShaderAst::AttributeType::Builtin)
{
auto it = builtinMapping.find(std::get<std::string>(attributeParam));
if (it != builtinMapping.end())
{
const Builtin& builtin = it->second;
if (!builtin.stageFlags.Test(shaderStage))
{
skip = true;
break;
}
builtinName = builtin.identifier;
break;
}
}
else if (attributeType == ShaderAst::AttributeType::Location)
{
attributeLocation = std::get<long long>(attributeParam);
break;
}
}
if (!skip && attributeLocation)
{
Append("layout(location = ");
Append(*attributeLocation);
Append(") ");
Append(keyword);
Append(" ");
Append(member.type);
Append(" ");
Append(targetPrefix);
Append(member.name);
AppendLine(";");
fields.push_back({
fromPrefix + member.name,
targetPrefix + member.name
});
}
else if (builtinName)
{
fields.push_back({
fromPrefix + member.name,
*builtinName
});
}
}
AppendLine();
return &s;
};
if (!m_currentState->entryFunc->parameters.empty())
{
assert(m_currentState->entryFunc->parameters.size() == 1);
const auto& parameter = m_currentState->entryFunc->parameters.front();
inputStruct = HandleInOutStructs(parameter.type, inputFields, "in", "_nzInput.", "_NzIn_");
}
std::vector<InOutField> outputFields;
const ShaderAst::StructDescription* outputStruct = nullptr;
if (!IsNoType(m_currentState->entryFunc->returnType))
outputStruct = HandleInOutStructs(m_currentState->entryFunc->returnType, outputFields, "out", "_nzOutput.", "_NzOut_");
if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition)
AppendLine("uniform float ", flipYUniformName, ";");
AppendLine("void main()");
EnterScope();
{
if (inputStruct)
{
Append(inputStruct->name);
AppendLine(" _nzInput;");
for (const auto& [name, targetName] : inputFields)
{
AppendLine(name, " = ", targetName, ";");
}
AppendLine();
}
if (outputStruct)
Append(outputStruct->name, " _nzOutput = ");
Append(m_currentState->entryFunc->name);
Append("(");
if (m_currentState->entryFunc)
Append("_nzInput");
Append(");");
if (outputStruct)
{
AppendLine();
for (const auto& [name, targetName] : outputFields)
{
bool isOutputPosition = (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition && targetName == "gl_Position");
AppendLine();
Append(targetName, " = ", name);
if (isOutputPosition)
Append(" * vec4(1.0, ", flipYUniformName, ", 1.0, 1.0)");
Append(";");
}
}
}
LeaveScope();
}
void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
Append(".");
@@ -273,7 +503,7 @@ namespace Nz
const auto& member = *memberIt;
if (remainingMembers > 1)
AppendField(scopeId, std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
}
void GlslWriter::AppendLine(const std::string& txt)
@@ -283,6 +513,13 @@ namespace Nz
m_currentState->stream << txt << '\n' << std::string(m_currentState->indentLevel, '\t');
}
template<typename... Args>
void GlslWriter::AppendLine(Args&&... params)
{
(Append(std::forward<Args>(params)), ...);
AppendLine();
}
void GlslWriter::EnterScope()
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@@ -291,13 +528,17 @@ namespace Nz
AppendLine("{");
}
void GlslWriter::LeaveScope()
void GlslWriter::LeaveScope(bool skipLine)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
m_currentState->indentLevel--;
AppendLine();
AppendLine("}");
if (skipLine)
AppendLine("}");
else
Append("}");
}
void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired)
@@ -317,12 +558,12 @@ namespace Nz
{
Visit(node.structExpr, true);
const ShaderAst::ShaderExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
assert(IsStructType(exprType));
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
assert(IsIdentifierType(exprType));
std::size_t scopeId = m_currentState->cache.GetScopeId(&node);
AppendField(scopeId, std::get<std::string>(exprType), node.memberIdentifiers.data(), node.memberIdentifiers.size());
AppendField(scopeId, std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
@@ -336,7 +577,7 @@ namespace Nz
break;
}
node.left->Visit(*this);
node.right->Visit(*this);
}
void GlslWriter::Visit(ShaderAst::BranchStatement& node)
@@ -455,6 +696,71 @@ namespace Nz
}, node.value);
}
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{
for (const auto& externalVar : node.externalVars)
{
std::optional<long long> bindingIndex;
bool isStd140 = false;
for (const auto& [attributeType, attributeParam] : externalVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
bindingIndex = std::get<long long>(attributeParam);
else if (attributeType == ShaderAst::AttributeType::Layout)
{
if (std::get<std::string>(attributeParam) == "std140")
isStd140 = true;
}
}
if (bindingIndex)
{
Append("layout(binding = ");
Append(*bindingIndex);
if (isStd140)
Append(", std140");
Append(") uniform ");
if (IsUniformType(externalVar.type))
{
Append("_NzBinding_");
AppendLine(externalVar.name);
EnterScope();
{
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(0, std::get<ShaderAst::UniformType>(externalVar.type).containedType.name);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
bool first = true;
for (const auto& [name, attribute, type] : s.members)
{
if (!first)
AppendLine();
first = false;
Append(type);
Append(" ");
Append(name);
Append(";");
}
}
LeaveScope(false);
}
else
Append(externalVar.type);
Append(" ");
Append(externalVar.name);
AppendLine(";");
}
}
}
void GlslWriter::Visit(ShaderAst::DeclareFunctionStatement& node)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@@ -475,15 +781,36 @@ namespace Nz
EnterScope();
{
AstAdapter adapter;
adapter.flipYPosition = m_environment.flipYPosition;
for (auto& statement : node.statements)
adapter.Clone(statement)->Visit(*this);
statement->Visit(*this);
}
LeaveScope();
}
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
Append("struct ");
AppendLine(node.description.name);
EnterScope();
{
bool first = true;
for (const auto& [name, attribute, type] : node.description.members)
{
if (!first)
AppendLine();
first = false;
Append(type);
Append(" ");
Append(name);
Append(";");
}
}
LeaveScope(false);
AppendLine(";");
}
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
{
Append(node.varType);
@@ -506,7 +833,7 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::ExpressionStatement& node)
{
node.expression->Visit(*this);
Append(";");
AppendLine(";");
}
void GlslWriter::Visit(ShaderAst::IdentifierExpression& node)
@@ -525,6 +852,10 @@ namespace Nz
case ShaderAst::IntrinsicType::DotProduct:
Append("dot");
break;
case ShaderAst::IntrinsicType::SampleTexture:
Append("texture");
break;
}
Append("(");
@@ -624,4 +955,5 @@ namespace Nz
return false;
}
}

View File

@@ -42,6 +42,21 @@ namespace Nz::ShaderAst
return PopStatement();
}
std::unique_ptr<DeclareFunctionStatement> AstCloner::Clone(DeclareFunctionStatement& node)
{
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->attributes = node.attributes;
clone->name = node.name;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
clone->statements.push_back(CloneStatement(statement));
return clone;
}
void AstCloner::Visit(AccessMemberExpression& node)
{
auto clone = std::make_unique<AccessMemberExpression>();
@@ -162,21 +177,20 @@ namespace Nz::ShaderAst
PushStatement(std::move(clone));
}
void AstCloner::Visit(DeclareFunctionStatement& node)
void AstCloner::Visit(DeclareExternalStatement& node)
{
auto clone = std::make_unique<DeclareFunctionStatement>();
auto clone = std::make_unique<DeclareExternalStatement>();
clone->attributes = node.attributes;
clone->name = node.name;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
clone->statements.push_back(CloneStatement(statement));
clone->externalVars = node.externalVars;
PushStatement(std::move(clone));
}
void AstCloner::Visit(DeclareFunctionStatement& node)
{
PushStatement(Clone(node));
}
void AstCloner::Visit(DeclareStructStatement& node)
{
auto clone = std::make_unique<DeclareStructStatement>();

View File

@@ -9,16 +9,16 @@
namespace Nz::ShaderAst
{
ShaderExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache = nullptr)
ExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache)
{
m_cache = cache;
ShaderExpressionType type = GetExpressionTypeInternal(expression);
ExpressionType type = GetExpressionTypeInternal(expression);
m_cache = nullptr;
return type;
}
ShaderExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression)
ExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression)
{
m_lastExpressionType.reset();
@@ -28,6 +28,33 @@ namespace Nz::ShaderAst
return std::move(*m_lastExpressionType);
}
ExpressionType ExpressionTypeVisitor::ResolveAlias(Expression& expression, ExpressionType expressionType)
{
if (IsIdentifierType(expressionType))
{
auto scopeIt = m_cache->scopeIdByNode.find(&expression);
if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error");
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get<IdentifierType>(expressionType).name);
if (identifier && std::holds_alternative<AstCache::Alias>(identifier->value))
{
const AstCache::Alias& alias = std::get<AstCache::Alias>(identifier->value);
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ExpressionType>)
return arg;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, alias.value);
}
}
return expressionType;
}
void ExpressionTypeVisitor::Visit(Expression& expression)
{
if (m_cache)
@@ -51,6 +78,16 @@ namespace Nz::ShaderAst
void ExpressionTypeVisitor::Visit(AccessMemberExpression& node)
{
auto scopeIt = m_cache->scopeIdByNode.find(&node);
if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error");
ExpressionType expressionType = ResolveAlias(node, GetExpressionTypeInternal(*node.structExpr));
if (!IsIdentifierType(expressionType))
throw std::runtime_error("internal error");
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, std::get<IdentifierType>(expressionType).name);
throw std::runtime_error("unhandled accessmember expression");
}
@@ -70,38 +107,35 @@ namespace Nz::ShaderAst
case BinaryType::Divide:
case BinaryType::Multiply:
{
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.left);
assert(IsBasicType(leftExprType));
ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.left));
ExpressionType rightExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.right));
ShaderExpressionType rightExprType = GetExpressionTypeInternal(*node.right);
assert(IsBasicType(rightExprType));
switch (std::get<BasicType>(leftExprType))
if (IsPrimitiveType(leftExprType))
{
case BasicType::Boolean:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
m_lastExpressionType = std::move(leftExprType);
break;
switch (std::get<PrimitiveType>(leftExprType))
{
case PrimitiveType::Boolean:
m_lastExpressionType = std::move(leftExprType);
break;
case BasicType::Float1:
case BasicType::Int1:
case BasicType::Mat4x4:
case BasicType::UInt1:
m_lastExpressionType = std::move(rightExprType);
break;
case BasicType::Sampler2D:
case BasicType::Void:
break;
case PrimitiveType::Float32:
case PrimitiveType::Int32:
case PrimitiveType::UInt32:
m_lastExpressionType = std::move(rightExprType);
break;
}
}
else if (IsMatrixType(leftExprType))
{
if (IsVectorType(rightExprType))
m_lastExpressionType = std::move(rightExprType);
else
m_lastExpressionType = std::move(leftExprType);
}
else if (IsVectorType(leftExprType))
m_lastExpressionType = std::move(leftExprType);
else
throw std::runtime_error("validation failure");
break;
}
@@ -112,7 +146,7 @@ namespace Nz::ShaderAst
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompNe:
m_lastExpressionType = BasicType::Boolean;
m_lastExpressionType = PrimitiveType::Boolean;
break;
}
}
@@ -124,38 +158,38 @@ namespace Nz::ShaderAst
void ExpressionTypeVisitor::Visit(ConditionalExpression& node)
{
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.truePath);
assert(leftExprType == GetExpressionTypeInternal(*node.falsePath));
ExpressionType leftExprType = ResolveAlias(node, GetExpressionTypeInternal(*node.truePath));
assert(leftExprType == ResolveAlias(node, GetExpressionTypeInternal(*node.falsePath)));
m_lastExpressionType = std::move(leftExprType);
}
void ExpressionTypeVisitor::Visit(ConstantExpression& node)
{
m_lastExpressionType = std::visit([&](auto&& arg)
m_lastExpressionType = std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return BasicType::Boolean;
return PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return BasicType::Float1;
return PrimitiveType::Float32;
else if constexpr (std::is_same_v<T, Int32>)
return BasicType::Int1;
return PrimitiveType::Int32;
else if constexpr (std::is_same_v<T, UInt32>)
return BasicType::Int1;
return PrimitiveType::UInt32;
else if constexpr (std::is_same_v<T, Vector2f>)
return BasicType::Float2;
return VectorType{ 2, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector3f>)
return BasicType::Float3;
return VectorType{ 3, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector4f>)
return BasicType::Float4;
return VectorType{ 4, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector2i32>)
return BasicType::Int2;
return VectorType{ 2, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector3i32>)
return BasicType::Int3;
return VectorType{ 3, PrimitiveType::Int32 };
else if constexpr (std::is_same_v<T, Vector4i32>)
return BasicType::Int4;
return VectorType{ 4, PrimitiveType::Int32 };
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, node.value);
@@ -173,7 +207,7 @@ namespace Nz::ShaderAst
if (!identifier || !std::holds_alternative<AstCache::Variable>(identifier->value))
throw std::runtime_error("internal error");
m_lastExpressionType = std::get<AstCache::Variable>(identifier->value).type;
m_lastExpressionType = ResolveAlias(node, std::get<AstCache::Variable>(identifier->value).type);
}
void ExpressionTypeVisitor::Visit(IntrinsicExpression& node)
@@ -185,16 +219,40 @@ namespace Nz::ShaderAst
break;
case IntrinsicType::DotProduct:
m_lastExpressionType = BasicType::Float1;
m_lastExpressionType = PrimitiveType::Float32;
break;
case IntrinsicType::SampleTexture:
{
if (node.parameters.empty())
throw std::runtime_error("validation failure");
ExpressionType firstParamType = ResolveAlias(node, GetExpressionTypeInternal(*node.parameters.front()));
if (!IsSamplerType(firstParamType))
throw std::runtime_error("validation failure");
const auto& sampler = std::get<SamplerType>(firstParamType);
m_lastExpressionType = VectorType{
4,
sampler.sampledType
};
break;
}
}
}
void ExpressionTypeVisitor::Visit(SwizzleExpression& node)
{
ShaderExpressionType exprType = GetExpressionTypeInternal(*node.expression);
assert(IsBasicType(exprType));
ExpressionType exprType = GetExpressionTypeInternal(*node.expression);
m_lastExpressionType = static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + node.componentCount - 1);
if (IsMatrixType(exprType))
m_lastExpressionType = std::get<MatrixType>(exprType).type;
else if (IsVectorType(exprType))
m_lastExpressionType = std::get<VectorType>(exprType).type;
else
throw std::runtime_error("validation failure");
}
}

View File

@@ -453,8 +453,8 @@ namespace Nz::ShaderAst
{
auto& constant = static_cast<ConstantExpression&>(*cond);
assert(IsBasicType(GetExpressionType(constant)));
assert(std::get<BasicType>(GetExpressionType(constant)) == BasicType::Boolean);
assert(IsPrimitiveType(GetExpressionType(constant)));
assert(std::get<PrimitiveType>(GetExpressionType(constant)) == PrimitiveType::Boolean);
bool cValue = std::get<bool>(constant.value);
if (!cValue)

View File

@@ -79,6 +79,11 @@ namespace Nz::ShaderAst
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareExternalStatement& node)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(DeclareFunctionStatement& node)
{
for (auto& statement : node.statements)

View File

@@ -58,7 +58,7 @@ namespace Nz::ShaderAst
void AstSerializerBase::Serialize(CastExpression& node)
{
Enum(node.targetType);
Type(node.targetType);
for (auto& expr : node.expressions)
Node(expr);
}
@@ -152,17 +152,25 @@ namespace Nz::ShaderAst
Node(node.statement);
}
void AstSerializerBase::Serialize(DeclareExternalStatement& node)
{
Attributes(node.attributes);
Container(node.externalVars);
for (auto& extVar : node.externalVars)
{
Attributes(extVar.attributes);
Value(extVar.name);
Type(extVar.type);
}
}
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
{
Value(node.name);
Type(node.returnType);
Container(node.attributes);
for (auto& attribute : node.attributes)
{
Enum(attribute.type);
Value(attribute.args);
}
Attributes(node.attributes);
Container(node.parameters);
for (auto& parameter : node.parameters)
@@ -223,6 +231,78 @@ namespace Nz::ShaderAst
m_stream.FlushBits();
}
void AstSerializerBase::Attributes(std::vector<Attribute>& attributes)
{
Container(attributes);
for (auto& attribute : attributes)
{
Enum(attribute.type);
if (IsWriting())
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::monostate>)
{
UInt8 typeId = 0;
Value(typeId);
}
else if constexpr (std::is_same_v<T, long long>)
{
UInt8 typeId = 1;
UInt64 v = UInt64(arg);
Value(typeId);
Value(v);
}
else if constexpr (std::is_same_v<T, std::string>)
{
UInt8 typeId = 2;
Value(typeId);
Value(arg);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, attribute.args);
}
else
{
UInt8 typeId;
Value(typeId);
switch (typeId)
{
case 0:
attribute.args.emplace<std::monostate>();
break;
case 1:
{
UInt64 arg;
Value(arg);
attribute.args = static_cast<long long>(arg);
break;
}
case 2:
{
std::string arg;
Value(arg);
attribute.args = std::move(arg);
break;
}
default:
throw std::runtime_error("invalid attribute type id");
}
}
}
}
bool ShaderAstSerializer::IsWriting() const
{
@@ -253,20 +333,47 @@ namespace Nz::ShaderAst
}
}
void ShaderAstSerializer::Type(ShaderExpressionType& type)
void ShaderAstSerializer::Type(ExpressionType& type)
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, BasicType>)
{
if constexpr (std::is_same_v<T, NoType>)
m_stream << UInt8(0);
m_stream << UInt32(arg);
}
else if constexpr (std::is_same_v<T, std::string>)
else if constexpr (std::is_same_v<T, PrimitiveType>)
{
m_stream << UInt8(1);
m_stream << arg;
m_stream << UInt32(arg);
}
else if constexpr (std::is_same_v<T, IdentifierType>)
{
m_stream << UInt8(2);
m_stream << arg.name;
}
else if constexpr (std::is_same_v<T, MatrixType>)
{
m_stream << UInt8(3);
m_stream << UInt32(arg.columnCount);
m_stream << UInt32(arg.rowCount);
m_stream << UInt32(arg.type);
}
else if constexpr (std::is_same_v<T, SamplerType>)
{
m_stream << UInt8(4);
m_stream << UInt32(arg.dim);
m_stream << UInt32(arg.sampledType);
}
else if constexpr (std::is_same_v<T, UniformType>)
{
m_stream << UInt8(5);
m_stream << arg.containedType.name;
}
else if constexpr (std::is_same_v<T, VectorType>)
{
m_stream << UInt8(6);
m_stream << UInt32(arg.componentCount);
m_stream << UInt32(arg.type);
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
@@ -421,28 +528,123 @@ namespace Nz::ShaderAst
}
}
void ShaderAstUnserializer::Type(ShaderExpressionType& type)
void ShaderAstUnserializer::Type(ExpressionType& type)
{
UInt8 typeIndex;
Value(typeIndex);
switch (typeIndex)
{
case 0: //< Primitive
/*
if constexpr (std::is_same_v<T, NoType>)
m_stream << UInt8(0);
else if constexpr (std::is_same_v<T, PrimitiveType>)
{
BasicType exprType;
Enum(exprType);
m_stream << UInt8(1);
m_stream << UInt32(arg);
}
else if constexpr (std::is_same_v<T, IdentifierType>)
{
m_stream << UInt8(2);
m_stream << arg.name;
}
else if constexpr (std::is_same_v<T, MatrixType>)
{
m_stream << UInt8(3);
m_stream << UInt32(arg.columnCount);
m_stream << UInt32(arg.rowCount);
m_stream << UInt32(arg.type);
}
else if constexpr (std::is_same_v<T, SamplerType>)
{
m_stream << UInt8(4);
m_stream << UInt32(arg.dim);
m_stream << UInt32(arg.sampledType);
}
else if constexpr (std::is_same_v<T, VectorType>)
{
m_stream << UInt8(5);
m_stream << UInt32(arg.componentCount);
m_stream << UInt32(arg.type);
}
*/
type = exprType;
case 0: //< NoType
type = NoType{};
break;
case 1: //< PrimitiveType
{
PrimitiveType primitiveType;
Enum(primitiveType);
type = primitiveType;
break;
}
case 1: //< Struct (name)
case 2: //< Identifier
{
std::string structName;
Value(structName);
std::string identifier;
Value(identifier);
type = std::move(structName);
type = IdentifierType{ std::move(identifier) };
break;
}
case 3: //< MatrixType
{
UInt32 columnCount, rowCount;
PrimitiveType primitiveType;
Value(columnCount);
Value(rowCount);
Enum(primitiveType);
type = MatrixType {
columnCount,
rowCount,
primitiveType
};
break;
}
case 4: //< SamplerType
{
ImageType dim;
PrimitiveType sampledType;
Enum(dim);
Enum(sampledType);
type = SamplerType {
dim,
sampledType
};
break;
}
case 5: //< UniformType
{
std::string containedType;
Value(containedType);
type = UniformType {
IdentifierType {
containedType
}
};
break;
}
case 6: //< VectorType
{
UInt32 componentCount;
PrimitiveType componentType;
Value(componentCount);
Enum(componentType);
type = VectorType{
componentCount,
componentType
};
break;
}

View File

@@ -18,7 +18,6 @@ namespace Nz::ShaderAst
{ "frag", ShaderStageType::Fragment },
{ "vert", ShaderStageType::Vertex },
};
}
struct AstError
@@ -30,6 +29,8 @@ namespace Nz::ShaderAst
{
//const ShaderAst::Function* currentFunction;
std::optional<std::size_t> activeScopeId;
std::unordered_set<std::string> declaredExternalVar;
std::unordered_set<long long> usedBindingIndexes;;
AstCache* cache;
};
@@ -81,31 +82,31 @@ namespace Nz::ShaderAst
return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache));
}
void AstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right)
void AstValidator::TypeMustMatch(const ExpressionType& left, const ExpressionType& right)
{
if (left != right)
throw AstError{ "Left expression type must match right expression type" };
}
ShaderExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
ExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName);
if (!identifier)
throw AstError{ "unknown identifier " + structName };
if (std::holds_alternative<StructDescription>(identifier->value))
if (!std::holds_alternative<StructDescription>(identifier->value))
throw AstError{ "identifier is not a struct" };
const StructDescription& s = std::get<StructDescription>(identifier->value);
auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
auto memberIt = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
if (memberIt == s.members.end())
throw AstError{ "unknown field " + memberIdentifier[0]};
const auto& member = *memberIt;
if (remainingMembers > 1)
return CheckField(std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
return CheckField(std::get<IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
else
return member.type;
}
@@ -130,7 +131,7 @@ namespace Nz::ShaderAst
m_context->activeScopeId = previousScope.parentScopeIndex;
}
void AstValidator::RegisterExpressionType(Expression& node, ShaderExpressionType expressionType)
void AstValidator::RegisterExpressionType(Expression& node, ExpressionType expressionType)
{
m_context->cache->nodeExpressionType[&node] = std::move(expressionType);
}
@@ -145,11 +146,14 @@ namespace Nz::ShaderAst
{
RegisterScope(node);
ShaderExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache);
if (!IsStructType(exprType))
// Register expressions types
AstRecursiveVisitor::Visit(node);
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache);
if (!IsIdentifierType(exprType))
throw AstError{ "expression is not a structure" };
const std::string& structName = std::get<std::string>(exprType);
const std::string& structName = std::get<IdentifierType>(exprType).name;
RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()));
}
@@ -160,12 +164,14 @@ namespace Nz::ShaderAst
MandatoryExpr(node.left);
MandatoryExpr(node.right);
// Register expressions types
AstRecursiveVisitor::Visit(node);
TypeMustMatch(node.left, node.right);
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
throw AstError { "Assignation is only possible with a l-value" };
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(BinaryExpression& node)
@@ -175,80 +181,121 @@ namespace Nz::ShaderAst
// Register expression type
AstRecursiveVisitor::Visit(node);
ShaderExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache);
if (!IsBasicType(leftExprType))
ExpressionType leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache);
if (!IsPrimitiveType(leftExprType) && !IsMatrixType(leftExprType) && !IsVectorType(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
ShaderExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache);
if (!IsBasicType(rightExprType))
ExpressionType rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache);
if (!IsPrimitiveType(rightExprType) && !IsMatrixType(rightExprType) && !IsVectorType(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
BasicType leftType = std::get<BasicType>(leftExprType);
BasicType rightType = std::get<BasicType>(rightExprType);
switch (node.op)
if (IsPrimitiveType(leftExprType))
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
if (leftType == BasicType::Boolean)
throw AstError{ "this operation is not supported for booleans" };
[[fallthrough]];
case BinaryType::Add:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case BinaryType::Multiply:
case BinaryType::Divide:
PrimitiveType leftType = std::get<PrimitiveType>(leftExprType);
switch (node.op)
{
switch (leftType)
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
if (leftType == PrimitiveType::Boolean)
throw AstError{ "this operation is not supported for booleans" };
[[fallthrough]];
case BinaryType::Add:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
case BasicType::Float1:
case BasicType::Int1:
switch (leftType)
{
if (GetComponentType(rightType) != leftType)
throw AstError{ "Left expression type is not compatible with right expression type" };
break;
}
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
{
if (leftType != rightType && rightType != GetComponentType(leftType))
throw AstError{ "Left expression type is not compatible with right expression type" };
break;
}
case BasicType::Mat4x4:
{
switch (rightType)
case PrimitiveType::Float32:
case PrimitiveType::Int32:
case PrimitiveType::UInt32:
{
case BasicType::Float1:
case BasicType::Float4:
case BasicType::Mat4x4:
break;
if (IsMatrixType(rightExprType))
TypeMustMatch(leftType, std::get<MatrixType>(rightExprType).type);
else if (IsVectorType(rightExprType))
TypeMustMatch(leftType, std::get<VectorType>(rightExprType).type);
else
throw AstError{ "incompatible types" };
default:
TypeMustMatch(node.left, node.right);
break;
}
break;
}
case PrimitiveType::Boolean:
throw AstError{ "this operation is not supported for booleans" };
default:
TypeMustMatch(node.left, node.right);
break;
default:
throw AstError{ "incompatible types" };
}
}
}
}
else if (IsMatrixType(leftExprType))
{
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
switch (node.op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsMatrixType(rightExprType))
TypeMustMatch(leftExprType, rightExprType);
else if (IsPrimitiveType(rightExprType))
TypeMustMatch(leftType.type, rightExprType);
else if (IsVectorType(rightExprType))
{
const VectorType& rightType = std::get<VectorType>(rightExprType);
TypeMustMatch(leftType.type, rightType.type);
if (leftType.columnCount != rightType.componentCount)
throw AstError{ "incompatible types" };
}
else
throw AstError{ "incompatible types" };
}
}
}
else if (IsVectorType(leftExprType))
{
const MatrixType& leftType = std::get<MatrixType>(leftExprType);
switch (node.op)
{
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Add:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case BinaryType::Multiply:
case BinaryType::Divide:
{
if (IsPrimitiveType(rightExprType))
TypeMustMatch(leftType.type, rightExprType);
else
throw AstError{ "incompatible types" };
}
}
}
@@ -258,24 +305,35 @@ namespace Nz::ShaderAst
{
RegisterScope(node);
AstRecursiveVisitor::Visit(node);
auto GetComponentCount = [](const ExpressionType& exprType) -> unsigned int
{
if (IsPrimitiveType(exprType))
return 1;
else if (IsVectorType(exprType))
return std::get<VectorType>(exprType).componentCount;
else
throw AstError{ "wut" };
};
unsigned int componentCount = 0;
unsigned int requiredComponents = GetComponentCount(node.targetType);
for (auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
ShaderExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
if (!IsBasicType(exprType))
ExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
throw AstError{ "incompatible type" };
componentCount += GetComponentCount(std::get<BasicType>(exprType));
componentCount += GetComponentCount(exprType);
}
if (componentCount != requiredComponents)
throw AstError{ "component count doesn't match required component count" };
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(ConditionalExpression& node)
@@ -313,34 +371,51 @@ namespace Nz::ShaderAst
{
RegisterScope(node);
AstRecursiveVisitor::Visit(node);
switch (node.intrinsic)
{
case IntrinsicType::CrossProduct:
case IntrinsicType::DotProduct:
{
if (node.parameters.size() != 2)
throw AstError { "Expected 2 parameters" };
throw AstError { "Expected two parameters" };
for (auto& param : node.parameters)
MandatoryExpr(param);
ShaderExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
ExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
for (std::size_t i = 1; i < node.parameters.size(); ++i)
{
if (type != GetExpressionType(MandatoryExpr(node.parameters[i])), m_context->cache)
if (type != GetExpressionType(MandatoryExpr(node.parameters[i]), m_context->cache))
throw AstError{ "All type must match" };
}
break;
}
case IntrinsicType::SampleTexture:
{
if (node.parameters.size() != 2)
throw AstError{ "Expected two parameters" };
for (auto& param : node.parameters)
MandatoryExpr(param);
if (!IsSamplerType(GetExpressionType(*node.parameters[0], m_context->cache)))
throw AstError{ "First parameter must be a sampler" };
if (!IsVectorType(GetExpressionType(*node.parameters[1], m_context->cache)))
throw AstError{ "First parameter must be a vector" };
}
}
switch (node.intrinsic)
{
case IntrinsicType::CrossProduct:
{
if (GetExpressionType(*node.parameters[0]) != ShaderExpressionType{ BasicType::Float3 }, m_context->cache)
throw AstError{ "CrossProduct only works with Float3 expressions" };
if (GetExpressionType(*node.parameters[0]) != ExpressionType{ VectorType{ 3, PrimitiveType::Float32 } })
throw AstError{ "CrossProduct only works with vec3<f32> expressions" };
break;
}
@@ -348,8 +423,6 @@ namespace Nz::ShaderAst
case IntrinsicType::DotProduct:
break;
}
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(SwizzleExpression& node)
@@ -359,26 +432,10 @@ namespace Nz::ShaderAst
if (node.componentCount > 4)
throw AstError{ "Cannot swizzle more than four elements" };
ShaderExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache);
if (!IsBasicType(exprType))
ExpressionType exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache);
if (!IsPrimitiveType(exprType) && !IsVectorType(exprType))
throw AstError{ "Cannot swizzle this type" };
switch (std::get<BasicType>(exprType))
{
case BasicType::Float1:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int1:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
break;
default:
throw AstError{ "Cannot swizzle this type" };
}
AstRecursiveVisitor::Visit(node);
}
@@ -388,8 +445,8 @@ namespace Nz::ShaderAst
for (auto& condStatement : node.condStatements)
{
ShaderExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache);
if (!IsBasicType(condType) || std::get<BasicType>(condType) != BasicType::Boolean)
ExpressionType condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache);
if (!IsPrimitiveType(condType) || std::get<PrimitiveType>(condType) != PrimitiveType::Boolean)
throw AstError{ "if expression must resolve to boolean type" };
MandatoryStatement(condStatement.statement);
@@ -409,6 +466,78 @@ namespace Nz::ShaderAst
// throw AstError{ "condition not found" };
}
void AstValidator::Visit(DeclareExternalStatement& node)
{
RegisterScope(node);
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
for (const auto& [attributeType, arg] : node.attributes)
{
switch (attributeType)
{
default:
throw AstError{ "unhandled attribute for external block" };
}
}
for (const auto& extVar : node.externalVars)
{
bool hasBinding = false;
bool hasLayout = false;
for (const auto& [attributeType, arg] : extVar.attributes)
{
switch (attributeType)
{
case AttributeType::Binding:
{
if (hasBinding)
throw AstError{ "attribute binding must be present once" };
if (!std::holds_alternative<long long>(arg))
throw AstError{ "attribute binding requires a string parameter" };
long long bindingIndex = std::get<long long>(arg);
if (m_context->usedBindingIndexes.find(bindingIndex) != m_context->usedBindingIndexes.end())
throw AstError{ "Binding #" + std::to_string(bindingIndex) + " is already in use" };
m_context->usedBindingIndexes.insert(bindingIndex);
break;
}
case AttributeType::Layout:
{
if (hasLayout)
throw AstError{ "attribute layout must be present once" };
if (!std::holds_alternative<std::string>(arg))
throw AstError{ "attribute layout requires a string parameter" };
if (std::get<std::string>(arg) != "std140")
throw AstError{ "unknow layout type" };
hasLayout = true;
break;
}
default:
throw AstError{ "unhandled attribute for external variable" };
}
}
if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end())
throw AstError{ "External variable " + extVar.name + " is already declared" };
m_context->declaredExternalVar.insert(extVar.name);
ExpressionType subType = extVar.type;
if (IsUniformType(subType))
subType = IdentifierType{ std::get<UniformType>(subType).containedType };
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ extVar.name, AstCache::Variable { std::move(subType) } };
}
}
void AstValidator::Visit(DeclareFunctionStatement& node)
{
bool hasEntry = false;
@@ -421,12 +550,14 @@ namespace Nz::ShaderAst
if (hasEntry)
throw AstError{ "attribute entry must be present once" };
if (arg.empty())
throw AstError{ "attribute entry requires a parameter" };
if (!std::holds_alternative<std::string>(arg))
throw AstError{ "attribute entry requires a string parameter" };
auto it = entryPoints.find(arg);
const std::string& argStr = std::get<std::string>(arg);
auto it = entryPoints.find(argStr);
if (it == entryPoints.end())
throw AstError{ "invalid parameter " + arg + " for entry attribute" };
throw AstError{ "invalid parameter " + argStr + " for entry attribute" };
ShaderStageType stageType = it->second;
@@ -435,6 +566,9 @@ namespace Nz::ShaderAst
m_context->cache->entryFunctions[UnderlyingCast(it->second)] = &node;
if (node.parameters.size() > 1)
throw AstError{ "entry functions can either take one struct parameter or no parameter" };
hasEntry = true;
break;
}
@@ -468,6 +602,8 @@ namespace Nz::ShaderAst
RegisterScope(node);
//TODO: check members attributes
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
auto& identifier = scope.identifiers.emplace_back();

View File

@@ -36,22 +36,24 @@ namespace Nz::ShaderLang
std::vector<Token> Tokenize(const std::string_view& str)
{
// Can't use std::from_chars for double thanks to libc++ and libstdc++ developers for being lazy
// Can't use std::from_chars for double, thanks to libc++ and libstdc++ developers for being lazy
ForceCLocale forceCLocale;
std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "let", TokenType::Let },
{ "return", TokenType::Return },
{ "true", TokenType::BoolTrue }
{ "external", TokenType::External },
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "let", TokenType::Let },
{ "return", TokenType::Return },
{ "struct", TokenType::Struct },
{ "true", TokenType::BoolTrue }
};
std::size_t currentPos = 0;
auto Peek = [&](std::size_t advance = 1) -> char
{
if (currentPos + advance < str.size())
if (currentPos + advance < str.size() && str[currentPos + advance] != '\0')
return str[currentPos + advance];
else
return char(-1);
@@ -134,7 +136,10 @@ namespace Nz::ShaderLang
{
currentPos++;
if (Peek() == '/')
{
currentPos++;
break;
}
}
else if (next == '\n')
{
@@ -250,7 +255,48 @@ namespace Nz::ShaderLang
break;
}
case '=': tokenType = TokenType::Assign; break;
case '=':
{
char next = Peek();
if (next == '=')
{
currentPos++;
tokenType = TokenType::Equal;
}
else
tokenType = TokenType::Assign;
break;
}
case '<':
{
char next = Peek();
if (next == '=')
{
currentPos++;
tokenType = TokenType::LessThanEqual;
}
else
tokenType = TokenType::LessThan;
break;
}
case '>':
{
char next = Peek();
if (next == '=')
{
currentPos++;
tokenType = TokenType::GreatherThanEqual;
}
else
tokenType = TokenType::GreatherThan;
break;
}
case '+': tokenType = TokenType::Plus; break;
case '*': tokenType = TokenType::Multiply; break;
case ':': tokenType = TokenType::Colon; break;

View File

@@ -11,32 +11,24 @@ namespace Nz::ShaderLang
{
namespace
{
std::unordered_map<std::string, ShaderAst::BasicType> identifierToBasicType = {
{ "bool", ShaderAst::BasicType::Boolean },
std::unordered_map<std::string, ShaderAst::PrimitiveType> identifierToBasicType = {
{ "bool", ShaderAst::PrimitiveType::Boolean },
{ "i32", ShaderAst::PrimitiveType::Int32 },
{ "f32", ShaderAst::PrimitiveType::Float32 },
{ "u32", ShaderAst::PrimitiveType::UInt32 }
};
{ "i32", ShaderAst::BasicType::Int1 },
{ "vec2i32", ShaderAst::BasicType::Int2 },
{ "vec3i32", ShaderAst::BasicType::Int3 },
{ "vec4i32", ShaderAst::BasicType::Int4 },
{ "f32", ShaderAst::BasicType::Float1 },
{ "vec2f32", ShaderAst::BasicType::Float2 },
{ "vec3f32", ShaderAst::BasicType::Float3 },
{ "vec4f32", ShaderAst::BasicType::Float4 },
{ "mat4x4f32", ShaderAst::BasicType::Mat4x4 },
{ "sampler2D", ShaderAst::BasicType::Sampler2D },
{ "void", ShaderAst::BasicType::Void },
{ "u32", ShaderAst::BasicType::UInt1 },
{ "vec2u32", ShaderAst::BasicType::UInt3 },
{ "vec3u32", ShaderAst::BasicType::UInt3 },
{ "vec4u32", ShaderAst::BasicType::UInt4 },
std::unordered_map<std::string, ShaderAst::IntrinsicType> identifierToIntrinsic = {
{ "cross", ShaderAst::IntrinsicType::CrossProduct },
{ "dot", ShaderAst::IntrinsicType::DotProduct },
};
std::unordered_map<std::string, ShaderAst::AttributeType> identifierToAttributeType = {
{ "entry", ShaderAst::AttributeType::Entry },
{ "layout", ShaderAst::AttributeType::Layout },
{ "binding", ShaderAst::AttributeType::Binding },
{ "builtin", ShaderAst::AttributeType::Builtin },
{ "entry", ShaderAst::AttributeType::Entry },
{ "layout", ShaderAst::AttributeType::Layout },
{ "location", ShaderAst::AttributeType::Location },
};
}
@@ -50,22 +42,41 @@ namespace Nz::ShaderLang
m_context = &context;
std::vector<ShaderAst::Attribute> attributes;
EnterScope();
bool reachedEndOfStream = false;
while (!reachedEndOfStream)
{
const Token& nextToken = Peek();
switch (nextToken.type)
{
case TokenType::EndOfStream:
if (!attributes.empty())
throw UnexpectedToken{};
reachedEndOfStream = true;
break;
case TokenType::External:
context.root->statements.push_back(ParseExternalBlock(std::move(attributes)));
attributes.clear();
break;
case TokenType::OpenAttribute:
HandleAttributes();
assert(attributes.empty());
attributes = ParseAttributes();
break;
case TokenType::FunctionDeclaration:
context.root->statements.push_back(ParseFunctionDeclaration());
context.root->statements.push_back(ParseFunctionDeclaration(std::move(attributes)));
attributes.clear();
break;
case TokenType::EndOfStream:
reachedEndOfStream = true;
case TokenType::Struct:
context.root->statements.push_back(ParseStructDeclaration(std::move(attributes)));
attributes.clear();
break;
default:
@@ -73,6 +84,8 @@ namespace Nz::ShaderLang
}
}
LeaveScope();
return std::move(context.root);
}
@@ -90,6 +103,92 @@ namespace Nz::ShaderLang
m_context->tokenIndex += count;
}
ShaderAst::ExpressionType Parser::DecodeType(const std::string& identifier)
{
if (auto it = identifierToBasicType.find(identifier); it != identifierToBasicType.end())
return it->second;
//FIXME: Handle this better
if (identifier == "mat4")
{
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 4;
matrixType.rowCount = 4;
Expect(Advance(), TokenType::LessThan); //< '<'
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return matrixType;
}
else if (identifier == "sampler2D")
{
ShaderAst::SamplerType samplerType;
samplerType.dim = ImageType_2D;
Expect(Advance(), TokenType::LessThan); //< '<'
samplerType.sampledType = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return samplerType;
}
else if (identifier == "uniform")
{
ShaderAst::UniformType uniformType;
Expect(Advance(), TokenType::LessThan); //< '<'
uniformType.containedType = ShaderAst::IdentifierType{ ParseIdentifierAsName() };
Expect(Advance(), TokenType::GreatherThan); //< '>'
return uniformType;
}
else if (identifier == "vec2")
{
ShaderAst::VectorType vectorType;
vectorType.componentCount = 2;
Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return vectorType;
}
else if (identifier == "vec3")
{
ShaderAst::VectorType vectorType;
vectorType.componentCount = 3;
Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return vectorType;
}
else if (identifier == "vec4")
{
ShaderAst::VectorType vectorType;
vectorType.componentCount = 4;
Expect(Advance(), TokenType::LessThan); //< '<'
vectorType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreatherThan); //< '>'
return vectorType;
}
else
{
ShaderAst::IdentifierType identifierType;
identifierType.name = identifier;
return identifierType;
}
}
void Parser::EnterScope()
{
m_context->scopeSizes.push_back(m_context->identifiersInScope.size());
}
const Token& Parser::Expect(const Token& token, TokenType type)
{
if (token.type != type)
@@ -114,13 +213,34 @@ namespace Nz::ShaderLang
return token;
}
void Parser::LeaveScope()
{
assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.resize(m_context->scopeSizes.back());
m_context->scopeSizes.pop_back();
}
bool Parser::IsVariableInScope(const std::string_view& identifier) const
{
return std::find(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), identifier) != m_context->identifiersInScope.rend();
}
void Parser::RegisterVariable(std::string identifier)
{
if (IsVariableInScope(identifier))
throw DuplicateIdentifier{ ("identifier name " + identifier + " is already taken").c_str() };
assert(!m_context->scopeSizes.empty());
m_context->identifiersInScope.push_back(std::move(identifier));
}
const Token& Parser::Peek(std::size_t advance)
{
assert(m_context->tokenIndex + advance < m_context->tokenCount);
return m_context->tokens[m_context->tokenIndex + advance];
}
void Parser::HandleAttributes()
std::vector<ShaderAst::Attribute> Parser::ParseAttributes()
{
std::vector<ShaderAst::Attribute> attributes;
@@ -150,13 +270,22 @@ namespace Nz::ShaderLang
ShaderAst::AttributeType attributeType = ParseIdentifierAsAttributeType();
std::string arg;
ShaderAst::Attribute::Param arg;
if (Peek().type == TokenType::OpenParenthesis)
{
Consume();
if (Peek().type == TokenType::Identifier)
arg = std::get<std::string>(Advance().data);
const Token& n = Peek();
if (n.type == TokenType::Identifier)
{
arg = std::get<std::string>(n.data);
Consume();
}
else if (n.type == TokenType::IntegerValue)
{
arg = std::get<long long>(n.data);
Consume();
}
Expect(Advance(), TokenType::ClosingParenthesis);
}
@@ -171,16 +300,54 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingAttribute);
const Token& nextToken = Peek();
switch (nextToken.type)
return attributes;
}
ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector<ShaderAst::Attribute> attributes)
{
Expect(Advance(), TokenType::External);
Expect(Advance(), TokenType::OpenCurlyBracket);
std::unique_ptr<ShaderAst::DeclareExternalStatement> externalStatement = std::make_unique<ShaderAst::DeclareExternalStatement>();
externalStatement->attributes = std::move(attributes);
bool first = true;
for (;;)
{
case TokenType::FunctionDeclaration:
m_context->root->statements.push_back(ParseFunctionDeclaration(std::move(attributes)));
if (!first)
{
const Token& nextToken = Peek();
if (nextToken.type == TokenType::Comma)
Consume();
else
{
Expect(nextToken, TokenType::ClosingCurlyBracket);
break;
}
}
first = false;
const Token& token = Peek();
if (token.type == TokenType::ClosingCurlyBracket)
break;
default:
throw UnexpectedToken{};
auto& extVar = externalStatement->externalVars.emplace_back();
if (token.type == TokenType::OpenAttribute)
extVar.attributes = ParseAttributes();
extVar.name = ParseIdentifierAsName();
Expect(Advance(), TokenType::Colon);
extVar.type = ParseType();
RegisterVariable(extVar.name);
}
Expect(Advance(), TokenType::ClosingCurlyBracket);
return externalStatement;
}
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
@@ -216,17 +383,23 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::ClosingParenthesis);
ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void;
ShaderAst::ExpressionType returnType;
if (Peek().type == TokenType::FunctionReturn)
{
Consume();
returnType = ParseIdentifierAsType();
returnType = ParseType();
}
Expect(Advance(), TokenType::OpenCurlyBracket);
EnterScope();
for (const auto& parameter : parameters)
RegisterVariable(parameter.name);
std::vector<ShaderAst::StatementPtr> functionBody = ParseFunctionBody();
LeaveScope();
Expect(Advance(), TokenType::ClosingCurlyBracket);
return ShaderBuilder::DeclareFunction(std::move(attributes), std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
@@ -238,11 +411,59 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Colon);
ShaderAst::ShaderExpressionType parameterType = ParseIdentifierAsType();
ShaderAst::ExpressionType parameterType = ParseType();
return { parameterName, parameterType };
}
ShaderAst::StatementPtr Parser::ParseStructDeclaration(std::vector<ShaderAst::Attribute> attributes)
{
Expect(Advance(), TokenType::Struct);
ShaderAst::StructDescription description;
description.name = ParseIdentifierAsName();
Expect(Advance(), TokenType::OpenCurlyBracket);
bool first = true;
for (;;)
{
if (!first)
{
const Token& nextToken = Peek();
if (nextToken.type == TokenType::Comma)
Consume();
else
{
Expect(nextToken, TokenType::ClosingCurlyBracket);
break;
}
}
first = false;
const Token& token = Peek();
if (token.type == TokenType::ClosingCurlyBracket)
break;
auto& structField = description.members.emplace_back();
if (token.type == TokenType::OpenAttribute)
structField.attributes = ParseAttributes();
structField.name = ParseIdentifierAsName();
Expect(Advance(), TokenType::Colon);
structField.type = ParseType();
}
Expect(Advance(), TokenType::ClosingCurlyBracket);
return ShaderBuilder::DeclareStruct(std::move(attributes), std::move(description));
}
ShaderAst::StatementPtr Parser::ParseReturnStatement()
{
Expect(Advance(), TokenType::Return);
@@ -265,6 +486,10 @@ namespace Nz::ShaderLang
statement = ParseVariableDeclaration();
break;
case TokenType::Identifier:
statement = ShaderBuilder::ExpressionStatement(ParseVariableAssignation());
break;
case TokenType::Return:
statement = ParseReturnStatement();
break;
@@ -290,15 +515,26 @@ namespace Nz::ShaderLang
return statements;
}
ShaderAst::ExpressionPtr Parser::ParseVariableAssignation()
{
ShaderAst::ExpressionPtr left = ParseIdentifier();
Expect(Advance(), TokenType::Assign);
ShaderAst::ExpressionPtr right = ParseExpression();
return ShaderBuilder::Assign(ShaderAst::AssignType::Simple, std::move(left), std::move(right));
}
ShaderAst::StatementPtr Parser::ParseVariableDeclaration()
{
Expect(Advance(), TokenType::Let);
std::string variableName = ParseIdentifierAsName();
RegisterVariable(variableName);
Expect(Advance(), TokenType::Colon);
ShaderAst::ShaderExpressionType variableType = ParseIdentifierAsType();
ShaderAst::ExpressionType variableType = ParseType();
ShaderAst::ExpressionPtr expression;
if (Peek().type == TokenType::Assign)
@@ -351,18 +587,61 @@ namespace Nz::ShaderLang
return ParseBinOpRhs(0, ParsePrimaryExpression());
}
ShaderAst::ExpressionPtr Parser::ParseFloatingPointExpression(bool minus)
{
const Token& floatingPointToken = Expect(Advance(), TokenType::FloatingPointValue);
return ShaderBuilder::Constant(((minus) ? -1.f : 1.f) * float(std::get<double>(floatingPointToken.data))); //< FIXME
}
ShaderAst::ExpressionPtr Parser::ParseIdentifier()
{
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
return ShaderBuilder::Identifier(identifier);
ShaderAst::ExpressionPtr identifierExpr = ShaderBuilder::Identifier(identifier);
if (Peek().type == TokenType::Dot)
{
std::unique_ptr<ShaderAst::AccessMemberExpression> accessMemberNode = std::make_unique<ShaderAst::AccessMemberExpression>();
accessMemberNode->structExpr = std::move(identifierExpr);
do
{
Consume();
accessMemberNode->memberIdentifiers.push_back(ParseIdentifierAsName());
} while (Peek().type == TokenType::Dot);
identifierExpr = std::move(accessMemberNode);
}
return identifierExpr;
}
ShaderAst::ExpressionPtr Parser::ParseIntegerExpression()
ShaderAst::ExpressionPtr Parser::ParseIntegerExpression(bool minus)
{
const Token& integerToken = Expect(Advance(), TokenType::Identifier);
return ShaderBuilder::Constant(static_cast<Nz::Int32>(std::get<long long>(integerToken.data)));
const Token& integerToken = Expect(Advance(), TokenType::IntegerValue);
return ShaderBuilder::Constant(((minus) ? -1 : 1) * static_cast<Nz::Int32>(std::get<long long>(integerToken.data)));
}
std::vector<ShaderAst::ExpressionPtr> Parser::ParseParameters()
{
Expect(Advance(), TokenType::OpenParenthesis);
std::vector<ShaderAst::ExpressionPtr> parameters;
bool first = true;
while (Peek().type != TokenType::ClosingParenthesis)
{
if (!first)
Expect(Advance(), TokenType::Comma);
first = false;
parameters.push_back(ParseExpression());
}
Expect(Advance(), TokenType::ClosingParenthesis);
return parameters;
}
ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression()
@@ -388,15 +667,69 @@ namespace Nz::ShaderLang
return ShaderBuilder::Constant(true);
case TokenType::FloatingPointValue:
Consume();
return ShaderBuilder::Constant(float(std::get<double>(token.data))); //< FIXME
return ParseFloatingPointExpression();
case TokenType::Identifier:
return ParseIdentifier();
{
const std::string& identifier = std::get<std::string>(token.data);
if (auto it = identifierToIntrinsic.find(identifier); it != identifierToIntrinsic.end())
{
if (Peek(1).type == TokenType::OpenParenthesis)
{
Consume();
return ShaderBuilder::Intrinsic(it->second, ParseParameters());
}
}
if (IsVariableInScope(identifier))
{
auto node = ParseIdentifier();
if (node->GetType() == ShaderAst::NodeType::AccessMemberExpression)
{
ShaderAst::AccessMemberExpression* memberExpr = static_cast<ShaderAst::AccessMemberExpression*>(node.get());
if (!memberExpr->memberIdentifiers.empty() && memberExpr->memberIdentifiers.front() == "Sample")
{
if (Peek().type == TokenType::OpenParenthesis)
{
auto parameters = ParseParameters();
parameters.insert(parameters.begin(), std::move(memberExpr->structExpr));
return ShaderBuilder::Intrinsic(ShaderAst::IntrinsicType::SampleTexture, std::move(parameters));
}
}
}
return node;
}
Consume();
ShaderAst::ExpressionType exprType = DecodeType(identifier);
return ShaderBuilder::Cast(std::move(exprType), ParseParameters());
}
case TokenType::IntegerValue:
return ParseIntegerExpression();
case TokenType::Minus:
//< FIXME: Handle this with an unary node
if (Peek(1).type == TokenType::FloatingPointValue)
{
Consume();
return ParseFloatingPointExpression(true);
}
else if (Peek(1).type == TokenType::IntegerValue)
{
Consume();
return ParseIntegerExpression(true);
}
else
throw UnexpectedToken{};
break;
case TokenType::OpenParenthesis:
return ParseParenthesisExpression();
@@ -429,7 +762,7 @@ namespace Nz::ShaderLang
return identifier;
}
ShaderAst::ShaderExpressionType Parser::ParseIdentifierAsType()
ShaderAst::PrimitiveType Parser::ParsePrimitiveType()
{
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
@@ -441,6 +774,23 @@ namespace Nz::ShaderLang
return it->second;
}
ShaderAst::ExpressionType Parser::ParseType()
{
// Handle () as no type
if (Peek().type == TokenType::OpenParenthesis)
{
Consume();
Expect(Advance(), TokenType::ClosingParenthesis);
return ShaderAst::NoType{};
}
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
const std::string& identifier = std::get<std::string>(identifierToken.data);
return DecodeType(identifier);
}
int Parser::GetTokenPrecedence(TokenType token)
{
switch (token)
@@ -452,4 +802,5 @@ namespace Nz::ShaderLang
default: return -1;
}
}
}

View File

@@ -39,18 +39,18 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
{
ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsBasicType(resultExprType));
ShaderAst::ExpressionType resultExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsPrimitiveType(resultExprType));
ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache);
assert(IsBasicType(leftExprType));
ShaderAst::ExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left, m_cache);
assert(IsPrimitiveType(leftExprType));
ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache);
assert(IsBasicType(rightExprType));
ShaderAst::ExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right, m_cache);
assert(IsPrimitiveType(rightExprType));
ShaderAst::BasicType resultType = std::get<ShaderAst::BasicType>(resultExprType);
ShaderAst::BasicType leftType = std::get<ShaderAst::BasicType>(leftExprType);
ShaderAst::BasicType rightType = std::get<ShaderAst::BasicType>(rightExprType);
ShaderAst::PrimitiveType resultType = std::get<ShaderAst::PrimitiveType>(resultExprType);
ShaderAst::PrimitiveType leftType = std::get<ShaderAst::PrimitiveType>(leftExprType);
ShaderAst::PrimitiveType rightType = std::get<ShaderAst::PrimitiveType>(rightExprType);
UInt32 leftOperand = EvaluateExpression(node.left);
@@ -67,26 +67,26 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFAdd;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpIAdd;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@@ -97,26 +97,26 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFSub;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpISub;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@@ -127,28 +127,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFDiv;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSDiv;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpUDiv;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@@ -159,29 +159,29 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Boolean:
case ShaderAst::PrimitiveType::Boolean:
return SpirvOp::OpLogicalEqual;
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdEqual;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpIEqual;
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
// break;
}
break;
@@ -191,28 +191,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdGreaterThan;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSGreaterThan;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpUGreaterThan;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@@ -223,28 +223,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdGreaterThanEqual;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSGreaterThanEqual;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpUGreaterThanEqual;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@@ -255,28 +255,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdLessThanEqual;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSLessThanEqual;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpULessThanEqual;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@@ -287,28 +287,28 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdLessThan;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
return SpirvOp::OpSLessThan;
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpULessThan;
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
case ShaderAst::PrimitiveType::Boolean:
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
break;
}
@@ -319,29 +319,29 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Boolean:
case ShaderAst::PrimitiveType::Boolean:
return SpirvOp::OpLogicalNotEqual;
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
case ShaderAst::PrimitiveType::Float32:
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// case ShaderAst::PrimitiveType::Mat4x4:
return SpirvOp::OpFOrdNotEqual;
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpINotEqual;
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
// case ShaderAst::PrimitiveType::Sampler2D:
// case ShaderAst::PrimitiveType::Void:
// break;
}
break;
@@ -351,22 +351,22 @@ namespace Nz
{
switch (leftType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::PrimitiveType::Float32:
{
switch (rightType)
{
case ShaderAst::BasicType::Float1:
case ShaderAst::PrimitiveType::Float32:
return SpirvOp::OpFMul;
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
case ShaderAst::BasicType::Mat4x4:
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// swapOperands = true;
// return SpirvOp::OpVectorTimesScalar;
//
// case ShaderAst::PrimitiveType::Mat4x4:
// swapOperands = true;
// return SpirvOp::OpMatrixTimesScalar;
default:
break;
@@ -375,54 +375,54 @@ namespace Nz
break;
}
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
{
switch (rightType)
{
case ShaderAst::BasicType::Float1:
return SpirvOp::OpVectorTimesScalar;
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// {
// switch (rightType)
// {
// case ShaderAst::PrimitiveType::Float32:
// return SpirvOp::OpVectorTimesScalar;
//
// case ShaderAst::PrimitiveType::Float2:
// case ShaderAst::PrimitiveType::Float3:
// case ShaderAst::PrimitiveType::Float4:
// return SpirvOp::OpFMul;
//
// case ShaderAst::PrimitiveType::Mat4x4:
// return SpirvOp::OpVectorTimesMatrix;
//
// default:
// break;
// }
//
// break;
// }
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
return SpirvOp::OpFMul;
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpVectorTimesMatrix;
default:
break;
}
break;
}
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
case ShaderAst::PrimitiveType::Int32:
// case ShaderAst::PrimitiveType::Int2:
// case ShaderAst::PrimitiveType::Int3:
// case ShaderAst::PrimitiveType::Int4:
case ShaderAst::PrimitiveType::UInt32:
// case ShaderAst::PrimitiveType::UInt2:
// case ShaderAst::PrimitiveType::UInt3:
// case ShaderAst::PrimitiveType::UInt4:
return SpirvOp::OpIMul;
case ShaderAst::BasicType::Mat4x4:
{
switch (rightType)
{
case ShaderAst::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
case ShaderAst::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
default:
break;
}
break;
}
// case ShaderAst::PrimitiveType::Mat4x4:
// {
// switch (rightType)
// {
// case ShaderAst::PrimitiveType::Float32: return SpirvOp::OpMatrixTimesScalar;
// case ShaderAst::PrimitiveType::Float4: return SpirvOp::OpMatrixTimesVector;
// case ShaderAst::PrimitiveType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
//
// default:
// break;
// }
//
// break;
// }
default:
break;
@@ -501,10 +501,10 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node)
{
const ShaderAst::ShaderExpressionType& targetExprType = node.targetType;
assert(IsBasicType(targetExprType));
const ShaderAst::ExpressionType& targetExprType = node.targetType;
assert(IsPrimitiveType(targetExprType));
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
@@ -582,12 +582,12 @@ namespace Nz
{
case ShaderAst::IntrinsicType::DotProduct:
{
ShaderAst::ShaderExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache);
assert(IsBasicType(vecExprType));
ShaderAst::ExpressionType vecExprType = GetExpressionType(*node.parameters[0], m_cache);
assert(IsVectorType(vecExprType));
ShaderAst::BasicType vecType = std::get<ShaderAst::BasicType>(vecExprType);
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
UInt32 typeId = m_writer.GetTypeId(ShaderAst::GetComponentType(vecType));
UInt32 typeId = m_writer.GetTypeId(vecType.type);
UInt32 vec1 = EvaluateExpression(node.parameters[0]);
UInt32 vec2 = EvaluateExpression(node.parameters[1]);
@@ -626,10 +626,10 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{
ShaderAst::ShaderExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsBasicType(targetExprType));
ShaderAst::ExpressionType targetExprType = ShaderAst::GetExpressionType(node, m_cache);
assert(IsPrimitiveType(targetExprType));
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
UInt32 exprResultId = EvaluateExpression(node.expression);
UInt32 resultId = m_writer.AllocateResultId();

View File

@@ -535,7 +535,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector2f>) ? ShaderAst::BasicType::Float2 : ShaderAst::BasicType::Int2),
BuildType(ShaderAst::VectorType{ 2, (std::is_same_v<T, Vector2f>) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }),
{
BuildConstant(arg.x),
BuildConstant(arg.y)
@@ -545,7 +545,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector3f>) ? ShaderAst::BasicType::Float3 : ShaderAst::BasicType::Int3),
BuildType(ShaderAst::VectorType{ 3, (std::is_same_v<T, Vector3f>) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }),
{
BuildConstant(arg.x),
BuildConstant(arg.y),
@@ -556,7 +556,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector4f> || std::is_same_v<T, Vector4i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector4f>) ? ShaderAst::BasicType::Float4 : ShaderAst::BasicType::Int4),
BuildType(ShaderAst::VectorType{ 4, (std::is_same_v<T, Vector4f>) ? ShaderAst::PrimitiveType::Float32 : ShaderAst::PrimitiveType::Int32 }),
{
BuildConstant(arg.x),
BuildConstant(arg.y),
@@ -570,7 +570,7 @@ namespace Nz
}, value));
}
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ShaderExpressionType& retType, const std::vector<ShaderAst::ShaderExpressionType>& parameters) -> TypePtr
auto SpirvConstantCache::BuildFunctionType(const ShaderAst::ExpressionType& retType, const std::vector<ShaderAst::ExpressionType>& parameters) -> TypePtr
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
@@ -584,7 +584,7 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
{
return std::make_shared<Type>(Pointer{
BuildType(type),
@@ -592,85 +592,22 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::PrimitiveType& type, SpirvStorageClass storageClass) -> TypePtr
{
return std::make_shared<Type>(Pointer{
BuildType(type),
storageClass
});
});
}
auto SpirvConstantCache::BuildType(const ShaderAst::BasicType& type) -> TypePtr
{
return std::make_shared<Type>([&]() -> AnyType
{
switch (type)
{
case ShaderAst::BasicType::Boolean:
return Bool{};
case ShaderAst::BasicType::Float1:
return Float{ 32 };
case ShaderAst::BasicType::Int1:
return Integer{ 32, true };
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
{
auto vecType = BuildType(ShaderAst::GetComponentType(type));
UInt32 componentCount = ShaderAst::GetComponentCount(type);
return Vector{ vecType, componentCount };
}
case ShaderAst::BasicType::Mat4x4:
return Matrix{ BuildType(ShaderAst::BasicType::Float4), 4u };
case ShaderAst::BasicType::UInt1:
return Integer{ 32, false };
case ShaderAst::BasicType::Void:
return Void{};
case ShaderAst::BasicType::Sampler2D:
{
auto imageType = Image{
{}, //< qualifier
{}, //< depth
{}, //< sampled
SpirvDim::Dim2D, //< dim
SpirvImageFormat::Unknown, //< format
BuildType(ShaderAst::BasicType::Float1), //< sampledType
false, //< arrayed,
false //< multisampled
};
return SampledImage{ std::make_shared<Type>(imageType) };
}
}
throw std::runtime_error("unexpected type");
}());
}
auto SpirvConstantCache::BuildType(const ShaderAst::ShaderExpressionType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::ExpressionType& type) -> TypePtr
{
return std::visit([&](auto&& arg) -> TypePtr
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderAst::BasicType>)
return BuildType(arg);
else if constexpr (std::is_same_v<T, std::string>)
return BuildType(arg);
/*else if constexpr (std::is_same_v<T, std::string>)
{
/*// Register struct members type
// Register struct members type
const auto& structs = shader.GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
if (it == structs.end())
@@ -688,14 +625,77 @@ namespace Nz
sMembers.type = BuildType(shader, member.type);
}
return std::make_shared<Type>(std::move(sType));*/
return std::make_shared<Type>(std::move(sType));
return nullptr;
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");*/
}, type);
}
auto SpirvConstantCache::BuildType(const ShaderAst::IdentifierType& type) -> TypePtr
{
throw std::runtime_error("unexpected type");
}
auto SpirvConstantCache::BuildType(const ShaderAst::PrimitiveType& type) -> TypePtr
{
return std::make_shared<Type>([&]() -> AnyType
{
switch (type)
{
case ShaderAst::PrimitiveType::Boolean:
return Bool{};
case ShaderAst::PrimitiveType::Float32:
return Float{ 32 };
case ShaderAst::PrimitiveType::Int32:
return Integer{ 32, true };
}
throw std::runtime_error("unexpected type");
}());
}
auto SpirvConstantCache::BuildType(const ShaderAst::MatrixType& type) -> TypePtr
{
return std::make_shared<Type>(
Matrix{
BuildType(ShaderAst::VectorType {
UInt32(type.rowCount), type.type
}),
UInt32(type.columnCount)
});
}
auto SpirvConstantCache::BuildType(const ShaderAst::NoType& type) -> TypePtr
{
return std::make_shared<Type>(Void{});
}
auto SpirvConstantCache::BuildType(const ShaderAst::SamplerType& type) -> TypePtr
{
//TODO
auto imageType = Image{
{}, //< qualifier
{}, //< depth
{}, //< sampled
SpirvDim::Dim2D, //< dim
SpirvImageFormat::Unknown, //< format
BuildType(ShaderAst::PrimitiveType::Float32), //< sampledType
false, //< arrayed,
false //< multisampled
};
return std::make_shared<Type>(SampledImage{ std::make_shared<Type>(imageType) });
}
auto SpirvConstantCache::BuildType(const ShaderAst::VectorType& type) -> TypePtr
{
return std::make_shared<Type>(Vector{ BuildType(type.type), UInt32(type.componentCount) });
}
void SpirvConstantCache::Write(const AnyConstant& constant, UInt32 resultId, SpirvSection& constants)
{
std::visit([&](auto&& arg)

View File

@@ -29,7 +29,7 @@ namespace Nz
{
public:
using ExtInstList = std::unordered_set<std::string>;
using LocalContainer = std::unordered_set<ShaderAst::ShaderExpressionType>;
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
@@ -81,7 +81,7 @@ namespace Nz
{
funcs.emplace_back(node);
std::vector<ShaderAst::ShaderExpressionType> parameterTypes;
std::vector<ShaderAst::ExpressionType> parameterTypes;
for (auto& parameter : node.parameters)
parameterTypes.push_back(parameter.type);
@@ -92,8 +92,17 @@ namespace Nz
void Visit(ShaderAst::DeclareStructStatement& node) override
{
for (auto& field : node.description.members)
m_constantCache.Register(*SpirvConstantCache::BuildType(field.type));
SpirvConstantCache::Structure sType;
sType.name = node.description.name;
for (const auto& [name, attribute, type] : node.description.members)
{
auto& sMembers = sType.members.emplace_back();
sMembers.name = name;
sMembers.type = SpirvConstantCache::BuildType(type);
}
m_constantCache.Register(SpirvConstantCache::Type{ std::move(sType) });
}
void Visit(ShaderAst::DeclareVariableStatement& node) override
@@ -137,26 +146,26 @@ namespace Nz
};
template<typename T>
constexpr ShaderAst::BasicType GetBasicType()
constexpr ShaderAst::PrimitiveType GetBasicType()
{
if constexpr (std::is_same_v<T, bool>)
return ShaderAst::BasicType::Boolean;
return ShaderAst::PrimitiveType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return(ShaderAst::BasicType::Float1);
return(ShaderAst::PrimitiveType::Float32);
else if constexpr (std::is_same_v<T, Int32>)
return(ShaderAst::BasicType::Int1);
return(ShaderAst::PrimitiveType::Int32);
else if constexpr (std::is_same_v<T, Vector2f>)
return(ShaderAst::BasicType::Float2);
return(ShaderAst::PrimitiveType::Float2);
else if constexpr (std::is_same_v<T, Vector3f>)
return(ShaderAst::BasicType::Float3);
return(ShaderAst::PrimitiveType::Float3);
else if constexpr (std::is_same_v<T, Vector4f>)
return(ShaderAst::BasicType::Float4);
return(ShaderAst::PrimitiveType::Float4);
else if constexpr (std::is_same_v<T, Vector2i32>)
return(ShaderAst::BasicType::Int2);
return(ShaderAst::PrimitiveType::Int2);
else if constexpr (std::is_same_v<T, Vector3i32>)
return(ShaderAst::BasicType::Int3);
return(ShaderAst::PrimitiveType::Int3);
else if constexpr (std::is_same_v<T, Vector4i32>)
return(ShaderAst::BasicType::Int4);
return(ShaderAst::PrimitiveType::Int4);
else
static_assert(AlwaysFalse<T>::value, "unhandled type");
}
@@ -394,7 +403,7 @@ namespace Nz
if (!state.functionBlocks.back().IsTerminated())
{
assert(func.returnType == ShaderAst::ShaderExpressionType(ShaderAst::BasicType::Void));
assert(func.returnType == ShaderAst::ExpressionType{ ShaderAst::NoType{} });
state.functionBlocks.back().Append(SpirvOp::OpReturn);
}
@@ -537,12 +546,12 @@ namespace Nz
return it.value();
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ShaderExpressionType& type) const
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ExpressionType& type) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
}
@@ -643,12 +652,12 @@ namespace Nz
return m_currentState->constantTypeCache.Register({ *BuildFunctionType(functionNode) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass)
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderAst::ShaderExpressionType type)
UInt32 SpirvWriter::RegisterType(ShaderAst::ExpressionType type)
{
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
@@ -662,7 +671,7 @@ namespace Nz
SpirvConstantCache::TypePtr SpirvWriter::BuildFunctionType(const ShaderAst::DeclareFunctionStatement& functionNode)
{
std::vector<ShaderAst::ShaderExpressionType> parameterTypes;
std::vector<ShaderAst::ExpressionType> parameterTypes;
parameterTypes.reserve(functionNode.parameters.size());
for (const auto& parameter : functionNode.parameters)