Shader: Add module statement

This commit is contained in:
Jérôme Leclercq
2022-03-01 19:36:18 +01:00
parent ad892dfb43
commit 99e07e6e1e
56 changed files with 418 additions and 123 deletions

View File

@@ -834,6 +834,7 @@ namespace Nz::ShaderAst
case PrimitiveType::Float32: optimized = PropagateSingleValueCast<float>(constantExpr); break;
case PrimitiveType::Int32: optimized = PropagateSingleValueCast<Int32>(constantExpr); break;
case PrimitiveType::UInt32: optimized = PropagateSingleValueCast<UInt32>(constantExpr); break;
case PrimitiveType::String: break;
}
}
}
@@ -866,7 +867,7 @@ namespace Nz::ShaderAst
if constexpr (std::is_same_v<T, NoValue>)
throw std::runtime_error("invalid type (value expected)");
else if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
else if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32> || std::is_same_v<T, std::string>)
constantValues.push_back(arg);
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i32>)
{

View File

@@ -5,6 +5,7 @@
#include <Nazara/Shader/Ast/AstSerializer.hpp>
#include <Nazara/Shader/Ast/AstExpressionVisitor.hpp>
#include <Nazara/Shader/Ast/AstStatementVisitor.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
@@ -120,7 +121,7 @@ namespace Nz::ShaderAst
Value(value);
};
static_assert(std::variant_size_v<decltype(node.value)> == 11);
static_assert(std::variant_size_v<decltype(node.value)> == 12);
switch (typeIndex)
{
case 0: break;
@@ -134,6 +135,7 @@ namespace Nz::ShaderAst
case 8: SerializeValue(Vector2i32()); break;
case 9: SerializeValue(Vector3i32()); break;
case 10: SerializeValue(Vector4i32()); break;
case 11: SerializeValue(std::string()); break;
default: throw std::runtime_error("unexpected data type");
}
}
@@ -327,11 +329,12 @@ namespace Nz::ShaderAst
Node(node.body);
}
void ShaderAstSerializer::Serialize(StatementPtr& shader)
void ShaderAstSerializer::Serialize(Module& module)
{
m_stream << s_magicNumber << s_currentVersion;
Node(shader);
m_stream << module.shaderLangVersion;
Serialize(*module.rootNode);
m_stream.FlushBits();
}
@@ -514,7 +517,7 @@ namespace Nz::ShaderAst
m_stream << val;
}
StatementPtr ShaderAstUnserializer::Unserialize()
ModulePtr ShaderAstUnserializer::Unserialize()
{
UInt32 magicNumber;
UInt32 version;
@@ -526,13 +529,15 @@ namespace Nz::ShaderAst
if (version > s_currentVersion)
throw std::runtime_error("unsupported version");
StatementPtr node;
ModulePtr module = std::make_shared<Module>();
Node(node);
if (!node)
throw std::runtime_error("functions can only have statements");
m_stream >> module->shaderLangVersion;
return node;
module->rootNode = ShaderBuilder::MultiStatement();
ShaderSerializerVisitor visitor(*this);
module->rootNode->Visit(visitor);
return module;
}
bool ShaderAstUnserializer::IsWriting() const
@@ -827,18 +832,18 @@ namespace Nz::ShaderAst
}
ByteArray SerializeShader(StatementPtr& shader)
ByteArray SerializeShader(Module& module)
{
ByteArray byteArray;
ByteStream stream(&byteArray, OpenModeFlags(OpenMode::WriteOnly));
ShaderAstSerializer serializer(stream);
serializer.Serialize(shader);
serializer.Serialize(module);
return byteArray;
}
StatementPtr UnserializeShader(ByteStream& stream)
ModulePtr UnserializeShader(ByteStream& stream)
{
ShaderAstUnserializer unserializer(stream);
return unserializer.Unserialize();

View File

@@ -24,6 +24,8 @@ namespace Nz::ShaderAst
return PrimitiveType::Int32;
else if constexpr (std::is_same_v<T, UInt32>)
return PrimitiveType::UInt32;
else if constexpr (std::is_same_v<T, std::string>)
return PrimitiveType::String;
else if constexpr (std::is_same_v<T, Vector2f>)
return VectorType{ 2, PrimitiveType::Float32 };
else if constexpr (std::is_same_v<T, Vector3f>)

View File

@@ -112,9 +112,10 @@ namespace Nz::ShaderAst
std::vector<StatementPtr>* currentStatementList = nullptr;
};
StatementPtr SanitizeVisitor::Sanitize(Statement& statement, const Options& options, std::string* error)
ModulePtr SanitizeVisitor::Sanitize(Module& module, const Options& options, std::string* error)
{
StatementPtr clone;
ModulePtr clone = std::make_shared<Module>();
clone->shaderLangVersion = module.shaderLangVersion;
Context currentContext;
currentContext.options = options;
@@ -129,7 +130,7 @@ namespace Nz::ShaderAst
// First pass, evaluate everything except function code
try
{
clone = AstCloner::Clone(statement);
clone->rootNode = static_unique_pointer_cast<MultiStatement>(AstCloner::Clone(*module.rootNode));
}
catch (const AstError& err)
{

View File

@@ -152,7 +152,7 @@ namespace Nz
unsigned int indentLevel = 0;
};
std::string GlslWriter::Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::Statement& shader, const BindingMapping& bindingMapping, const States& states)
std::string GlslWriter::Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::Module& module, const BindingMapping& bindingMapping, const States& states)
{
State state(bindingMapping);
state.optionValues = states.optionValues;
@@ -164,15 +164,15 @@ namespace Nz
m_currentState = nullptr;
});
ShaderAst::StatementPtr sanitizedAst;
ShaderAst::ModulePtr sanitizedModule;
ShaderAst::Statement* targetAst;
if (!states.sanitized)
{
sanitizedAst = Sanitize(shader, states.optionValues);
targetAst = sanitizedAst.get();
sanitizedModule = Sanitize(module, states.optionValues);
targetAst = sanitizedModule->rootNode.get();
}
else
targetAst = &shader;
targetAst = module.rootNode.get();
ShaderAst::StatementPtr optimizedAst;
@@ -210,7 +210,7 @@ namespace Nz
return s_flipYUniformName;
}
ShaderAst::StatementPtr GlslWriter::Sanitize(ShaderAst::Statement& ast, std::unordered_map<std::size_t, ShaderAst::ConstantValue> optionValues, std::string* error)
ShaderAst::ModulePtr GlslWriter::Sanitize(ShaderAst::Module& module, std::unordered_map<std::size_t, ShaderAst::ConstantValue> optionValues, std::string* error)
{
// Always sanitize for reserved identifiers
ShaderAst::SanitizeVisitor::Options options;
@@ -228,7 +228,7 @@ namespace Nz
"cross", "dot", "exp", "length", "max", "min", "pow", "texture"
};
return ShaderAst::Sanitize(ast, options, error);
return ShaderAst::Sanitize(module, options, error);
}
void GlslWriter::Append(const ShaderAst::ArrayType& /*type*/)
@@ -856,6 +856,8 @@ namespace Nz
if constexpr (std::is_same_v<T, ShaderAst::NoValue>)
throw std::runtime_error("invalid type (value expected)");
else if constexpr (std::is_same_v<T, std::string>)
throw std::runtime_error("unexpected string litteral");
else if constexpr (std::is_same_v<T, bool>)
Append((arg) ? "true" : "false");
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, Int32>)

View File

@@ -76,6 +76,13 @@ namespace Nz
inline bool HasValue() const { return locationIndex.HasValue(); }
};
struct LangWriter::NzslAttribute
{
const ShaderAst::ExpressionValue<UInt32>& version;
inline bool HasValue() const { return version.HasValue(); }
};
struct LangWriter::SetAttribute
{
const ShaderAst::ExpressionValue<UInt32>& setIndex;
@@ -101,7 +108,7 @@ namespace Nz
unsigned int indentLevel = 0;
};
std::string LangWriter::Generate(ShaderAst::Statement& shader, const States& /*states*/)
std::string LangWriter::Generate(ShaderAst::Module& module, const States& /*states*/)
{
State state;
m_currentState = &state;
@@ -110,11 +117,11 @@ namespace Nz
m_currentState = nullptr;
});
ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader);
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module);
AppendHeader();
sanitizedAst->Visit(*this);
sanitizedModule->rootNode->Visit(*this);
return state.stream.str();
}
@@ -453,6 +460,10 @@ namespace Nz
Append(")");
}
void LangWriter::AppendAttribute(NzslAttribute nzslVersion)
{
}
void LangWriter::AppendAttribute(SetAttribute set)
{
if (!set.HasValue())
@@ -766,6 +777,8 @@ namespace Nz
Append((arg) ? "true" : "false");
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)
Append(std::to_string(arg));
else if constexpr (std::is_same_v<T, std::string>)
Append('"', arg, '"'); //< TODO: Escape string
else if constexpr (std::is_same_v<T, Vector2f>)
Append("vec2[f32](" + std::to_string(arg.x) + ", " + std::to_string(arg.y) + ")");
else if constexpr (std::is_same_v<T, Vector2i32>)

View File

@@ -51,6 +51,7 @@ namespace Nz::ShaderLang
{ "if", TokenType::If },
{ "in", TokenType::In },
{ "let", TokenType::Let },
{ "module", TokenType::Module },
{ "option", TokenType::Option },
{ "return", TokenType::Return },
{ "struct", TokenType::Struct },
@@ -65,7 +66,7 @@ namespace Nz::ShaderLang
if (currentPos + advance < str.size() && str[currentPos + advance] != '\0')
return str[currentPos + advance];
else
return char(-1);
return '\0';
};
auto IsAlphaNum = [&](const char c)
@@ -85,7 +86,7 @@ namespace Nz::ShaderLang
token.column = static_cast<unsigned int>(currentPos - lastLineFeed);
token.line = lineNumber;
if (c == -1)
if (c == '\0')
{
token.type = TokenType::EndOfStream;
tokens.push_back(std::move(token));
@@ -385,6 +386,55 @@ namespace Nz::ShaderLang
case '[': tokenType = TokenType::OpenSquareBracket; break;
case ']': tokenType = TokenType::ClosingSquareBracket; break;
case '"':
{
// string litteral
currentPos++;
std::string litteral;
char current;
while ((current = Peek(0)) != '"')
{
char character;
switch (current)
{
case '\0':
case '\n':
case '\r':
throw UnfinishedString{};
case '\\':
{
currentPos++;
char next = Peek();
switch (next)
{
case 'n': character = '\n'; break;
case 'r': character = '\r'; break;
case 't': character = '\t'; break;
case '"': character = '"'; break;
case '\\': character = '\\'; break;
default:
throw UnrecognizedChar{};
}
break;
}
default:
character = current;
break;
}
litteral.push_back(character);
currentPos++;
}
tokenType = TokenType::StringValue;
token.data = std::move(litteral);
break;
}
default:
{
if (IsAlphaNum(c))

View File

@@ -7,6 +7,7 @@
#include <Nazara/Core/File.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <cassert>
#include <regex>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderLang
@@ -29,6 +30,7 @@ namespace Nz::ShaderLang
{ "entry", ShaderAst::AttributeType::Entry },
{ "layout", ShaderAst::AttributeType::Layout },
{ "location", ShaderAst::AttributeType::Location },
{ "nzsl_version", ShaderAst::AttributeType::LangVersion },
{ "set", ShaderAst::AttributeType::Set },
{ "unroll", ShaderAst::AttributeType::Unroll },
};
@@ -106,18 +108,24 @@ namespace Nz::ShaderLang
}
}
ShaderAst::StatementPtr Parser::Parse(const std::vector<Token>& tokens)
ShaderAst::ModulePtr Parser::Parse(const std::vector<Token>& tokens)
{
Context context;
context.tokenCount = tokens.size();
context.tokens = tokens.data();
context.root = std::make_unique<ShaderAst::MultiStatement>();
m_context = &context;
std::vector<ShaderAst::ExprValue> attributes;
auto EnsureModule = [this]() -> ShaderAst::Module&
{
if (!m_context->module)
throw UnexpectedToken{ "unexpected token before module declaration" };
return *m_context->module;
};
bool reachedEndOfStream = false;
while (!reachedEndOfStream)
{
@@ -125,11 +133,14 @@ namespace Nz::ShaderLang
switch (nextToken.type)
{
case TokenType::Const:
{
if (!attributes.empty())
throw UnexpectedToken{};
context.root->statements.push_back(ParseConstStatement());
const auto& module = EnsureModule();
module.rootNode->statements.push_back(ParseConstStatement());
break;
}
case TokenType::EndOfStream:
if (!attributes.empty())
@@ -139,38 +150,58 @@ namespace Nz::ShaderLang
break;
case TokenType::External:
context.root->statements.push_back(ParseExternalBlock(std::move(attributes)));
{
const auto& module = EnsureModule();
module.rootNode->statements.push_back(ParseExternalBlock(std::move(attributes)));
attributes.clear();
break;
}
case TokenType::OpenSquareBracket:
assert(attributes.empty());
attributes = ParseAttributes();
break;
case TokenType::Module:
if (attributes.empty())
throw UnexpectedToken{};
ParseModuleStatement(std::move(attributes));
attributes.clear();
break;
case TokenType::Option:
{
if (!attributes.empty())
throw UnexpectedToken{};
context.root->statements.push_back(ParseOptionDeclaration());
const auto& module = EnsureModule();
module.rootNode->statements.push_back(ParseOptionDeclaration());
break;
}
case TokenType::FunctionDeclaration:
context.root->statements.push_back(ParseFunctionDeclaration(std::move(attributes)));
{
const auto& module = EnsureModule();
module.rootNode->statements.push_back(ParseFunctionDeclaration(std::move(attributes)));
attributes.clear();
break;
}
case TokenType::Struct:
context.root->statements.push_back(ParseStructDeclaration(std::move(attributes)));
{
const auto& module = EnsureModule();
module.rootNode->statements.push_back(ParseStructDeclaration(std::move(attributes)));
attributes.clear();
break;
}
default:
throw UnexpectedToken{};
}
}
return std::move(context.root);
return std::move(context.module);
}
const Token& Parser::Advance()
@@ -604,6 +635,74 @@ namespace Nz::ShaderLang
return { parameterName, std::move(parameterType) };
}
void Parser::ParseModuleStatement(std::vector<ShaderAst::ExprValue> attributes)
{
Expect(Advance(), TokenType::Module);
if (m_context->module)
throw DuplicateModule{ "you must set one module statement per file" };
std::optional<UInt32> moduleVersion;
for (auto&& [attributeType, arg] : attributes)
{
switch (attributeType)
{
case ShaderAst::AttributeType::LangVersion:
{
// Version parsing
if (moduleVersion.has_value())
throw AttributeError{ "attribute " + std::string("nzsl_version") + " must be present once" };
if (!arg)
throw AttributeError{ "attribute " + std::string("nzsl_version") + " requires a parameter"};
const ShaderAst::ExpressionPtr& expr = *arg;
if (expr->GetType() != ShaderAst::NodeType::ConstantValueExpression)
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
const std::string& versionStr = std::get<std::string>(constantValue.value);
std::regex versionRegex(R"(^(\d+)(\.(\d+)(\.(\d+))?)?$)", std::regex::ECMAScript);
std::smatch versionMatch;
if (!std::regex_match(versionStr, versionMatch, versionRegex))
throw AttributeError("invalid version for attribute nzsl");
assert(versionMatch.size() == 6);
std::uint32_t version = 0;
version += std::stoi(versionMatch[1]) * 100;
if (versionMatch.length(3) > 0)
version += std::stoi(versionMatch[3]) * 10;
if (versionMatch.length(5) > 0)
version += std::stoi(versionMatch[5]) * 1;
moduleVersion = version;
break;
}
default:
throw AttributeError{ "unhandled attribute for module" };
}
}
if (!moduleVersion.has_value())
throw AttributeError{ "missing module version" };
m_context->module = std::make_shared<ShaderAst::Module>();
m_context->module->rootNode = ShaderBuilder::MultiStatement();
m_context->module->shaderLangVersion = *moduleVersion;
Expect(Advance(), TokenType::Semicolon);
}
ShaderAst::StatementPtr Parser::ParseOptionDeclaration()
{
Expect(Advance(), TokenType::Option);
@@ -1132,11 +1231,20 @@ namespace Nz::ShaderLang
case TokenType::OpenParenthesis:
return ParseParenthesisExpression();
case TokenType::StringValue:
return ParseStringExpression();
default:
throw UnexpectedToken{};
}
}
ShaderAst::ExpressionPtr Parser::ParseStringExpression()
{
const Token& litteralToken = Expect(Advance(), TokenType::StringValue);
return ShaderBuilder::Constant(std::get<std::string>(litteralToken.data));
}
ShaderAst::AttributeType Parser::ParseIdentifierAsAttributeType()
{
const Token& identifierToken = Expect(Advance(), TokenType::Identifier);
@@ -1192,7 +1300,7 @@ namespace Nz::ShaderLang
}
}
ShaderAst::StatementPtr ParseFromFile(const std::filesystem::path& sourcePath)
ShaderAst::ModulePtr ParseFromFile(const std::filesystem::path& sourcePath)
{
File file(sourcePath);
if (!file.Open(OpenMode::ReadOnly | OpenMode::Text))

View File

@@ -465,6 +465,8 @@ namespace Nz
if constexpr (std::is_same_v<T, ShaderAst::NoValue>)
throw std::runtime_error("invalid type (value expected)");
else if constexpr (std::is_same_v<T, std::string>)
throw std::runtime_error("unexpected string litteral");
else if constexpr (std::is_same_v<T, bool>)
return ConstantBool{ arg };
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, Int32> || std::is_same_v<T, UInt32>)

View File

@@ -496,11 +496,10 @@ namespace Nz
{
}
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::Statement& shader, const States& states)
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::Module& module, const States& states)
{
ShaderAst::Statement* targetAst = &shader;
ShaderAst::StatementPtr sanitizedAst;
ShaderAst::ModulePtr sanitizedModule;
ShaderAst::Statement* targetAst;
if (!states.sanitized)
{
ShaderAst::SanitizeVisitor::Options options;
@@ -512,9 +511,11 @@ namespace Nz
options.splitMultipleBranches = true;
options.useIdentifierAccessesForStructs = false;
sanitizedAst = ShaderAst::Sanitize(shader, options);
targetAst = sanitizedAst.get();
sanitizedModule = ShaderAst::Sanitize(module, options);
targetAst = sanitizedModule->rootNode.get();
}
else
targetAst = module.rootNode.get();
ShaderAst::StatementPtr optimizedAst;
if (states.optimize)