Modules are workings \o/

This commit is contained in:
Jérôme Leclercq
2022-03-08 20:26:02 +01:00
parent 83d26e209e
commit be9bdc4705
29 changed files with 742 additions and 256 deletions

View File

@@ -31,73 +31,95 @@ namespace Nz
{
const ShaderAst::ExpressionValue<UInt32>& bindingIndex;
inline bool HasValue() const { return bindingIndex.HasValue(); }
bool HasValue() const { return bindingIndex.HasValue(); }
};
struct LangWriter::BuiltinAttribute
{
const ShaderAst::ExpressionValue<ShaderAst::BuiltinEntry>& builtin;
inline bool HasValue() const { return builtin.HasValue(); }
bool HasValue() const { return builtin.HasValue(); }
};
struct LangWriter::DepthWriteAttribute
{
const ShaderAst::ExpressionValue<ShaderAst::DepthWriteMode>& writeMode;
inline bool HasValue() const { return writeMode.HasValue(); }
bool HasValue() const { return writeMode.HasValue(); }
};
struct LangWriter::EarlyFragmentTestsAttribute
{
const ShaderAst::ExpressionValue<bool>& earlyFragmentTests;
inline bool HasValue() const { return earlyFragmentTests.HasValue(); }
bool HasValue() const { return earlyFragmentTests.HasValue(); }
};
struct LangWriter::EntryAttribute
{
const ShaderAst::ExpressionValue<ShaderStageType>& stageType;
inline bool HasValue() const { return stageType.HasValue(); }
bool HasValue() const { return stageType.HasValue(); }
};
struct LangWriter::LangVersionAttribute
{
UInt32 version;
bool HasValue() const { return true; }
};
struct LangWriter::LayoutAttribute
{
const ShaderAst::ExpressionValue<StructLayout>& layout;
inline bool HasValue() const { return layout.HasValue(); }
bool HasValue() const { return layout.HasValue(); }
};
struct LangWriter::LocationAttribute
{
const ShaderAst::ExpressionValue<UInt32>& locationIndex;
inline bool HasValue() const { return locationIndex.HasValue(); }
bool HasValue() const { return locationIndex.HasValue(); }
};
struct LangWriter::SetAttribute
{
const ShaderAst::ExpressionValue<UInt32>& setIndex;
inline bool HasValue() const { return setIndex.HasValue(); }
bool HasValue() const { return setIndex.HasValue(); }
};
struct LangWriter::UnrollAttribute
{
const ShaderAst::ExpressionValue<ShaderAst::LoopUnroll>& unroll;
inline bool HasValue() const { return unroll.HasValue(); }
bool HasValue() const { return unroll.HasValue(); }
};
struct LangWriter::UuidAttribute
{
Uuid uuid;
bool HasValue() const { return true; }
};
struct LangWriter::State
{
struct Identifier
{
std::size_t moduleIndex;
std::string name;
};
const States* states = nullptr;
ShaderAst::Module* module;
std::size_t currentModuleIndex;
std::stringstream stream;
std::unordered_map<std::size_t, std::string> constantNames;
std::unordered_map<std::size_t, ShaderAst::StructDescription*> structs;
std::unordered_map<std::size_t, std::string> variableNames;
std::unordered_map<std::size_t, Identifier> constantNames;
std::unordered_map<std::size_t, Identifier> structs;
std::unordered_map<std::size_t, Identifier> variableNames;
std::vector<std::string> moduleNames;
bool isInEntryPoint = false;
unsigned int indentLevel = 0;
};
@@ -116,6 +138,22 @@ namespace Nz
AppendHeader();
// Register imported modules
m_currentState->currentModuleIndex = 0;
for (const auto& importedModule : sanitizedModule->importedModules)
{
AppendAttributes(true, LangVersionAttribute{ importedModule.module->metadata->shaderLangVersion });
AppendAttributes(true, UuidAttribute{ importedModule.module->metadata->moduleId });
AppendLine("module ", importedModule.identifier);
EnterScope();
importedModule.module->rootNode->Visit(*this);
LeaveScope(true);
m_currentState->currentModuleIndex++;
m_currentState->moduleNames.push_back(importedModule.identifier);
}
m_currentState->currentModuleIndex = std::numeric_limits<std::size_t>::max();
sanitizedModule->rootNode->Visit(*this);
return state.stream.str();
@@ -213,8 +251,7 @@ namespace Nz
void LangWriter::Append(const ShaderAst::StructType& structType)
{
ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structType.structIndex);
Append(structDesc->name);
AppendIdentifier(m_currentState->structs, structType.structIndex);
}
void LangWriter::Append(const ShaderAst::Type& /*type*/)
@@ -292,31 +329,31 @@ namespace Nz
AppendAttributesInternal(first, secondParam, std::forward<Rest>(params)...);
}
void LangWriter::AppendAttribute(BindingAttribute binding)
void LangWriter::AppendAttribute(BindingAttribute attribute)
{
if (!binding.HasValue())
if (!attribute.HasValue())
return;
Append("binding(");
if (binding.bindingIndex.IsResultingValue())
Append(binding.bindingIndex.GetResultingValue());
if (attribute.bindingIndex.IsResultingValue())
Append(attribute.bindingIndex.GetResultingValue());
else
binding.bindingIndex.GetExpression()->Visit(*this);
attribute.bindingIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(BuiltinAttribute builtin)
void LangWriter::AppendAttribute(BuiltinAttribute attribute)
{
if (!builtin.HasValue())
if (!attribute.HasValue())
return;
Append("builtin(");
if (builtin.builtin.IsResultingValue())
if (attribute.builtin.IsResultingValue())
{
switch (builtin.builtin.GetResultingValue())
switch (attribute.builtin.GetResultingValue())
{
case ShaderAst::BuiltinEntry::FragCoord:
Append("fragcoord");
@@ -332,21 +369,21 @@ namespace Nz
}
}
else
builtin.builtin.GetExpression()->Visit(*this);
attribute.builtin.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite)
void LangWriter::AppendAttribute(DepthWriteAttribute attribute)
{
if (!depthWrite.HasValue())
if (!attribute.HasValue())
return;
Append("depth_write(");
if (depthWrite.writeMode.IsResultingValue())
if (attribute.writeMode.IsResultingValue())
{
switch (depthWrite.writeMode.GetResultingValue())
switch (attribute.writeMode.GetResultingValue())
{
case ShaderAst::DepthWriteMode::Greater:
Append("greater");
@@ -366,41 +403,41 @@ namespace Nz
}
}
else
depthWrite.writeMode.GetExpression()->Visit(*this);
attribute.writeMode.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests)
void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute attribute)
{
if (!earlyFragmentTests.HasValue())
if (!attribute.HasValue())
return;
Append("early_fragment_tests(");
if (earlyFragmentTests.earlyFragmentTests.IsResultingValue())
if (attribute.earlyFragmentTests.IsResultingValue())
{
if (earlyFragmentTests.earlyFragmentTests.GetResultingValue())
if (attribute.earlyFragmentTests.GetResultingValue())
Append("true");
else
Append("false");
}
else
earlyFragmentTests.earlyFragmentTests.GetExpression()->Visit(*this);
attribute.earlyFragmentTests.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(EntryAttribute entry)
void LangWriter::AppendAttribute(EntryAttribute attribute)
{
if (!entry.HasValue())
if (!attribute.HasValue())
return;
Append("entry(");
if (entry.stageType.IsResultingValue())
if (attribute.stageType.IsResultingValue())
{
switch (entry.stageType.GetResultingValue())
switch (attribute.stageType.GetResultingValue())
{
case ShaderStageType::Fragment:
Append("frag");
@@ -412,20 +449,39 @@ namespace Nz
}
}
else
entry.stageType.GetExpression()->Visit(*this);
attribute.stageType.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(LayoutAttribute entry)
void LangWriter::AppendAttribute(LangVersionAttribute attribute)
{
if (!entry.HasValue())
UInt32 shaderLangVersion = attribute.version;
UInt32 majorVersion = shaderLangVersion / 100;
shaderLangVersion -= majorVersion * 100;
UInt32 minorVersion = shaderLangVersion / 10;
shaderLangVersion -= minorVersion * 100;
UInt32 patchVersion = shaderLangVersion;
// nzsl_version
Append("nzsl_version(\"", majorVersion, ".", minorVersion);
if (patchVersion != 0)
Append(".", patchVersion);
Append("\")");
}
void LangWriter::AppendAttribute(LayoutAttribute attribute)
{
if (!attribute.HasValue())
return;
Append("layout(");
if (entry.layout.IsResultingValue())
if (attribute.layout.IsResultingValue())
{
switch (entry.layout.GetResultingValue())
switch (attribute.layout.GetResultingValue())
{
case StructLayout::Packed:
Append("packed");
@@ -437,50 +493,50 @@ namespace Nz
}
}
else
entry.layout.GetExpression()->Visit(*this);
attribute.layout.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(LocationAttribute location)
void LangWriter::AppendAttribute(LocationAttribute attribute)
{
if (!location.HasValue())
if (!attribute.HasValue())
return;
Append("location(");
if (location.locationIndex.IsResultingValue())
Append(location.locationIndex.GetResultingValue());
if (attribute.locationIndex.IsResultingValue())
Append(attribute.locationIndex.GetResultingValue());
else
location.locationIndex.GetExpression()->Visit(*this);
attribute.locationIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(SetAttribute set)
void LangWriter::AppendAttribute(SetAttribute attribute)
{
if (!set.HasValue())
if (!attribute.HasValue())
return;
Append("set(");
if (set.setIndex.IsResultingValue())
Append(set.setIndex.GetResultingValue());
if (attribute.setIndex.IsResultingValue())
Append(attribute.setIndex.GetResultingValue());
else
set.setIndex.GetExpression()->Visit(*this);
attribute.setIndex.GetExpression()->Visit(*this);
Append(")");
}
void LangWriter::AppendAttribute(UnrollAttribute unroll)
void LangWriter::AppendAttribute(UnrollAttribute attribute)
{
if (!unroll.HasValue())
if (!attribute.HasValue())
return;
Append("unroll(");
if (unroll.unroll.IsResultingValue())
if (attribute.unroll.IsResultingValue())
{
switch (unroll.unroll.GetResultingValue())
switch (attribute.unroll.GetResultingValue())
{
case ShaderAst::LoopUnroll::Always:
Append("always");
@@ -499,7 +555,12 @@ namespace Nz
}
}
else
unroll.unroll.GetExpression()->Visit(*this);
attribute.unroll.GetExpression()->Visit(*this);
}
void LangWriter::AppendAttribute(UuidAttribute attribute)
{
Append("uuid(\"", attribute.uuid.ToString(), "\")");
}
void LangWriter::AppendComment(const std::string& section)
@@ -539,6 +600,16 @@ namespace Nz
m_currentState->stream << txt << '\n' << std::string(m_currentState->indentLevel, '\t');
}
template<typename T>
void LangWriter::AppendIdentifier(const T& map, std::size_t id)
{
const auto& structIdentifier = Retrieve(map, id);
if (structIdentifier.moduleIndex != m_currentState->currentModuleIndex)
Append(m_currentState->moduleNames[structIdentifier.moduleIndex], '.');
Append(structIdentifier.name);
}
template<typename... Args>
void LangWriter::AppendLine(Args&&... params)
{
@@ -586,20 +657,32 @@ namespace Nz
void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(constantName);
assert(m_currentState->constantNames.find(constantIndex) == m_currentState->constantNames.end());
m_currentState->constantNames.emplace(constantIndex, std::move(constantName));
m_currentState->constantNames.emplace(constantIndex, std::move(identifier));
}
void LangWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc)
void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(structName);
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
m_currentState->structs.emplace(structIndex, desc);
m_currentState->structs.emplace(structIndex, std::move(identifier));
}
void LangWriter::RegisterVariable(std::size_t varIndex, std::string varName)
{
State::Identifier identifier;
identifier.moduleIndex = m_currentState->currentModuleIndex;
identifier.name = std::move(varName);
assert(m_currentState->variableNames.find(varIndex) == m_currentState->variableNames.end());
m_currentState->variableNames.emplace(varIndex, std::move(varName));
m_currentState->variableNames.emplace(varIndex, std::move(identifier));
}
void LangWriter::ScopeVisit(ShaderAst::Statement& node)
@@ -770,7 +853,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareAliasStatement& node)
{
throw std::runtime_error("TODO"); //< missing registering
//throw std::runtime_error("TODO"); //< missing registering
assert(node.aliasIndex);
@@ -828,7 +911,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::ConstantExpression& node)
{
Append(Retrieve(m_currentState->constantNames, node.constantId));
AppendIdentifier(m_currentState->constantNames, node.constantId);
}
void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node)
@@ -908,7 +991,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
assert(node.structIndex);
RegisterStruct(*node.structIndex, &node.description);
RegisterStruct(*node.structIndex, node.description.name);
AppendAttributes(true, LayoutAttribute{ node.description.layout });
Append("struct ");
@@ -1068,6 +1151,11 @@ namespace Nz
Append(")");
}
void LangWriter::Visit(ShaderAst::StructTypeExpression& node)
{
AppendIdentifier(m_currentState->structs, node.structTypeId);
}
void LangWriter::Visit(ShaderAst::MultiStatement& node)
{
if (!node.sectionName.empty())
@@ -1100,7 +1188,7 @@ namespace Nz
{
EnterScope();
node.statement->Visit(*this);
LeaveScope(true);
LeaveScope();
}
void LangWriter::Visit(ShaderAst::SwizzleExpression& node)
@@ -1115,8 +1203,7 @@ namespace Nz
void LangWriter::Visit(ShaderAst::VariableExpression& node)
{
const std::string& varName = Retrieve(m_currentState->variableNames, node.variableId);
Append(varName);
AppendIdentifier(m_currentState->variableNames, node.variableId);
}
void LangWriter::Visit(ShaderAst::UnaryExpression& node)
@@ -1150,22 +1237,7 @@ namespace Nz
void LangWriter::AppendHeader()
{
UInt32 shaderLangVersion = m_currentState->module->metadata->shaderLangVersion;
UInt32 majorVersion = shaderLangVersion / 100;
shaderLangVersion -= majorVersion * 100;
UInt32 minorVersion = shaderLangVersion / 10;
shaderLangVersion -= minorVersion * 100;
UInt32 patchVersion = shaderLangVersion;
// nzsl_version
Append("[nzsl_version(\"", majorVersion, ".", minorVersion);
if (patchVersion != 0)
Append(".", patchVersion);
AppendLine("\")]");
AppendAttributes(true, LangVersionAttribute{ m_currentState->module->metadata->shaderLangVersion });
AppendLine("module;");
AppendLine();
}