Shader: Remove the need of layout(std140) in external block

This commit is contained in:
Jérôme Leclercq 2021-04-12 18:52:48 +02:00
parent 854bb16320
commit 3499c1f92f
8 changed files with 144 additions and 125 deletions

View File

@ -19,7 +19,7 @@ struct Data
external
{
[[binding(0), layout(std140)]] viewerData: uniform<Data>,
[[binding(0)]] viewerData: uniform<Data>,
[[binding(1)]] tex: sampler2D<f32>
}

View File

@ -95,6 +95,7 @@ namespace Nz::ShaderAst
inline bool IsNoType(const ExpressionType& type);
inline bool IsPrimitiveType(const ExpressionType& type);
inline bool IsSamplerType(const ExpressionType& type);
inline bool IsStructType(const ExpressionType& type);
inline bool IsUniformType(const ExpressionType& type);
inline bool IsVectorType(const ExpressionType& type);
}

View File

@ -109,6 +109,11 @@ namespace Nz::ShaderAst
return std::holds_alternative<SamplerType>(type);
}
bool IsStructType(const ExpressionType& type)
{
return std::holds_alternative<StructType>(type);
}
bool IsUniformType(const ExpressionType& type)
{
return std::holds_alternative<UniformType>(type);

View File

@ -9,7 +9,8 @@
#include <Nazara/Prerequisites.hpp>
#include <Nazara/Shader/Config.hpp>
#include <Nazara/Shader/ShaderAstScopedVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitorExcept.hpp>
#include <Nazara/Shader/ShaderWriter.hpp>
#include <set>
#include <sstream>
@ -17,7 +18,7 @@
namespace Nz
{
class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::AstScopedVisitor
class NAZARA_SHADER_API GlslWriter : public ShaderWriter, public ShaderAst::ExpressionVisitorExcept, public ShaderAst::StatementVisitorExcept
{
public:
struct Environment;
@ -59,24 +60,27 @@ namespace Nz
template<typename T1, typename T2, typename... Args> void Append(const T1& firstParam, const T2& secondParam, Args&&... params);
void AppendCommentSection(const std::string& section);
void AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader);
void AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers);
void AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers);
void AppendLine(const std::string& txt = {});
template<typename... Args> void AppendLine(Args&&... params);
void EnterScope();
void LeaveScope(bool skipLine = true);
void RegisterStruct(std::size_t structIndex, bool isStd140, ShaderAst::StructDescription desc);
void RegisterVariable(std::size_t varIndex, std::string varName);
void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false);
void Visit(ShaderAst::AccessMemberIdentifierExpression& node) override;
void Visit(ShaderAst::AccessMemberIndexExpression& node) override;
void Visit(ShaderAst::AssignExpression& node) override;
void Visit(ShaderAst::BinaryExpression& node) override;
void Visit(ShaderAst::CastExpression& node) override;
void Visit(ShaderAst::ConditionalExpression& node) override;
void Visit(ShaderAst::ConstantExpression& node) override;
void Visit(ShaderAst::IdentifierExpression& node) override;
void Visit(ShaderAst::IntrinsicExpression& node) override;
void Visit(ShaderAst::SwizzleExpression& node) override;
void Visit(ShaderAst::VariableExpression& node) override;
void Visit(ShaderAst::BranchStatement& node) override;
void Visit(ShaderAst::ConditionalStatement& node) override;

View File

@ -6,6 +6,7 @@
#include <Nazara/Core/MemoryView.hpp>
#include <Nazara/OpenGLRenderer/Utils.hpp>
#include <Nazara/Shader/GlslWriter.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstSerializer.hpp>
#include <Nazara/Shader/ShaderLangLexer.hpp>
#include <Nazara/Shader/ShaderLangParser.hpp>
@ -122,7 +123,9 @@ namespace Nz
if (!shader.Create(device, ToOpenGL(shaderStage)))
throw std::runtime_error("failed to create shader"); //< TODO: Handle error message
std::string code = writer.Generate(shaderStage, shaderAst, states);
ShaderAst::AstCloner cloner; //< FIXME: Required because writer may update AST
ShaderAst::StatementPtr clonedAst = cloner.Clone(shaderAst);
std::string code = writer.Generate(shaderStage, clonedAst, states);
shader.SetSource(code.data(), code.size());
shader.Compile();

View File

@ -10,6 +10,7 @@
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Shader/Ast/TransformVisitor.hpp>
#include <optional>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
@ -70,40 +71,6 @@ namespace Nz
ShaderAst::DeclareFunctionStatement* entryPoint = nullptr;
};
struct EntryFuncResolver : ShaderAst::AstScopedVisitor
{
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
if (&node != entryPoint)
return;
assert(node.parameters.size() == 1);
const ShaderAst::ExpressionType& inputType = node.parameters.front().type;
const ShaderAst::ExpressionType& outputType = node.returnType;
const Identifier* identifier;
assert(IsIdentifierType(node.parameters.front().type));
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(inputType).name);
assert(identifier);
inputIdentifier = *identifier;
assert(IsIdentifierType(outputType));
identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(outputType).name);
assert(identifier);
outputIdentifier = *identifier;
}
Identifier inputIdentifier;
Identifier outputIdentifier;
ShaderAst::DeclareFunctionStatement* entryPoint;
};
struct Builtin
{
std::string identifier;
@ -118,9 +85,17 @@ namespace Nz
struct GlslWriter::State
{
struct StructInfo
{
ShaderAst::StructDescription structDesc;
bool isStd140 = false;
};
const States* states = nullptr;
ShaderAst::DeclareFunctionStatement* entryFunc = nullptr;
std::stringstream stream;
std::unordered_map<std::size_t, StructInfo> structs;
std::unordered_map<std::size_t, std::string> variableNames;
unsigned int indentLevel = 0;
};
@ -143,10 +118,13 @@ namespace Nz
if (!ShaderAst::ValidateAst(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
ShaderAst::TransformVisitor transformVisitor;
ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader);
PreVisitor previsitor;
previsitor.selectedEntryPoint = shaderStage;
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(shader);
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(transformedShader);
if (!previsitor.entryPoint)
throw std::runtime_error("missing entry point");
@ -233,14 +211,10 @@ namespace Nz
AppendLine();
}
PushScope();
{
adaptedShader->Visit(*this);
adaptedShader->Visit(*this);
// Append true GLSL entry point
AppendEntryPoint(shaderStage, adaptedShader);
}
PopScope();
// Append true GLSL entry point
AppendEntryPoint(shaderStage, adaptedShader);
return state.stream.str();
}
@ -275,7 +249,7 @@ namespace Nz
void GlslWriter::Append(const ShaderAst::IdentifierType& identifierType)
{
Append(identifierType.name);
throw std::runtime_error("unexpected identifier type");
}
void GlslWriter::Append(const ShaderAst::MatrixType& matrixType)
@ -332,7 +306,8 @@ namespace Nz
void GlslWriter::Append(const ShaderAst::StructType& structType)
{
throw std::runtime_error("unexpected struct type");
const auto& structDesc = m_currentState->structs[structType.structIndex].structDesc;
Append(structDesc.name);
}
void GlslWriter::Append(const ShaderAst::UniformType& uniformType)
@ -395,9 +370,24 @@ namespace Nz
void GlslWriter::AppendEntryPoint(ShaderStageType shaderStage, ShaderAst::StatementPtr& shader)
{
EntryFuncResolver entryResolver;
entryResolver.entryPoint = m_currentState->entryFunc;
entryResolver.ScopedVisit(shader);
ShaderAst::DeclareFunctionStatement& entryFunc = *m_currentState->entryFunc;
std::optional<std::size_t> inputStructIndex;
if (!entryFunc.parameters.empty())
{
assert(entryFunc.parameters.size() == 1);
auto& parameter = entryFunc.parameters.front();
assert(std::holds_alternative<ShaderAst::StructType>(parameter.type));
inputStructIndex = std::get<ShaderAst::StructType>(parameter.type).structIndex;
}
std::optional<std::size_t> outputStructIndex;
if (!IsNoType(entryFunc.returnType))
{
assert(std::holds_alternative<ShaderAst::StructType>(entryFunc.returnType));
outputStructIndex = std::get<ShaderAst::StructType>(entryFunc.returnType).structIndex;
}
AppendLine();
AppendLine("// Entry point handling");
@ -411,12 +401,11 @@ namespace Nz
std::vector<InOutField> inputFields;
const ShaderAst::StructDescription* inputStruct = nullptr;
auto HandleInOutStructs = [this, shaderStage](const Identifier& identifier, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
auto HandleInOutStructs = [this, shaderStage](std::size_t structIndex, std::vector<InOutField>& fields, const char* keyword, const char* fromPrefix, const char* targetPrefix) -> const ShaderAst::StructDescription*
{
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier.value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier.value);
const auto& structDesc = m_currentState->structs[structIndex].structDesc;
for (const auto& member : s.members)
for (const auto& member : structDesc.members)
{
bool skip = false;
std::optional<std::string> builtinName;
@ -474,16 +463,16 @@ namespace Nz
}
AppendLine();
return &s;
return &structDesc;
};
if (!m_currentState->entryFunc->parameters.empty())
inputStruct = HandleInOutStructs(entryResolver.inputIdentifier, inputFields, "in", "_nzInput.", "_NzIn_");
if (inputStructIndex)
inputStruct = HandleInOutStructs(*inputStructIndex, inputFields, "in", "_nzInput.", "_NzIn_");
std::vector<InOutField> outputFields;
const ShaderAst::StructDescription* outputStruct = nullptr;
if (!IsNoType(m_currentState->entryFunc->returnType))
outputStruct = HandleInOutStructs(entryResolver.outputIdentifier, outputFields, "out", "_nzOutput.", "_NzOut_");
if (outputStructIndex)
outputStruct = HandleInOutStructs(*outputStructIndex, outputFields, "out", "_nzOutput.", "_NzOut_");
if (shaderStage == ShaderStageType::Vertex && m_environment.flipYPosition)
AppendLine("uniform float ", flipYUniformName, ";");
@ -533,24 +522,20 @@ namespace Nz
LeaveScope();
}
void GlslWriter::AppendField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
void GlslWriter::AppendField(std::size_t structIndex, const std::size_t* memberIndices, std::size_t remainingMembers)
{
const auto& structDesc = m_currentState->structs[structIndex].structDesc;
const auto& member = structDesc.members[*memberIndices];
Append(".");
Append(memberIdentifier[0]);
const Identifier* identifier = FindIdentifier(structName);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
assert(memberIt != s.members.end());
const auto& member = *memberIt;
Append(member.name);
if (remainingMembers > 1)
AppendField(std::get<ShaderAst::IdentifierType>(member.type).name, memberIdentifier + 1, remainingMembers - 1);
{
assert(IsStructType(member.type));
AppendField(std::get<ShaderAst::StructType>(member.type).structIndex, memberIndices + 1, remainingMembers - 1);
}
}
void GlslWriter::AppendLine(const std::string& txt)
@ -588,6 +573,21 @@ namespace Nz
Append("}");
}
void GlslWriter::RegisterStruct(std::size_t structIndex, bool isStd140, ShaderAst::StructDescription desc)
{
assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end());
m_currentState->structs.emplace(structIndex, State::StructInfo{
std::move(desc),
isStd140
});
}
void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName)
{
assert(m_currentState->variableNames.find(varIndex) == m_currentState->variableNames.end());
m_currentState->variableNames.emplace(varIndex, std::move(varName));
}
void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired)
{
bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue);
@ -601,14 +601,14 @@ namespace Nz
Append(")");
}
void GlslWriter::Visit(ShaderAst::AccessMemberIdentifierExpression& node)
void GlslWriter::Visit(ShaderAst::AccessMemberIndexExpression& node)
{
Visit(node.structExpr, true);
const ShaderAst::ExpressionType& exprType = node.structExpr->cachedExpressionType.value();
assert(IsIdentifierType(exprType));
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.structExpr);
assert(IsStructType(exprType));
AppendField(std::get<ShaderAst::IdentifierType>(exprType).name, node.memberIdentifiers.data(), node.memberIdentifiers.size());
AppendField(std::get<ShaderAst::StructType>(exprType).structIndex, node.memberIndices.data(), node.memberIndices.size());
}
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
@ -638,9 +638,7 @@ namespace Nz
AppendLine(")");
EnterScope();
PushScope();
statement.statement->Visit(*this);
PopScope();
LeaveScope();
first = false;
@ -651,9 +649,7 @@ namespace Nz
AppendLine("else");
EnterScope();
PushScope();
node.elseStatement->Visit(*this);
PopScope();
LeaveScope();
}
}
@ -747,21 +743,32 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareExternalStatement& node)
{
assert(node.varIndex);
std::size_t varIndex = *node.varIndex;
for (const auto& externalVar : node.externalVars)
{
std::optional<long long> bindingIndex;
bool isStd140 = false;
for (const auto& [attributeType, attributeParam] : externalVar.attributes)
{
if (attributeType == ShaderAst::AttributeType::Binding)
bindingIndex = std::get<long long>(attributeParam);
else if (attributeType == ShaderAst::AttributeType::Layout)
{
if (std::get<std::string>(attributeParam) == "std140")
isStd140 = true;
bindingIndex = std::get<long long>(attributeParam);
break;
}
}
bool isStd140 = false;
if (IsUniformType(externalVar.type))
{
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type);
assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType));
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structInfo = m_currentState->structs[structIndex];
isStd140 = structInfo.isStd140;
}
if (bindingIndex)
{
Append("layout(binding = ");
@ -773,19 +780,20 @@ namespace Nz
if (IsUniformType(externalVar.type))
{
Append("_NzBinding_");
AppendLine(externalVar.name);
EnterScope();
{
const Identifier* identifier = FindIdentifier(std::get<ShaderAst::IdentifierType>(std::get<ShaderAst::UniformType>(externalVar.type).containedType).name);
assert(identifier);
auto& uniform = std::get<ShaderAst::UniformType>(externalVar.type);
assert(std::holds_alternative<ShaderAst::StructType>(uniform.containedType));
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
std::size_t structIndex = std::get<ShaderAst::StructType>(uniform.containedType).structIndex;
auto& structInfo = m_currentState->structs[structIndex];
bool first = true;
for (const auto& [name, attribute, type] : s.members)
for (const auto& [name, attribute, type] : structInfo.structDesc.members)
{
if (!first)
AppendLine();
@ -807,6 +815,8 @@ namespace Nz
Append(externalVar.name);
AppendLine(";");
}
RegisterVariable(varIndex++, externalVar.name);
}
}
@ -814,6 +824,9 @@ namespace Nz
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
assert(node.varIndex);
std::size_t varIndex = *node.varIndex;
Append(node.returnType);
Append(" ");
Append(node.name);
@ -825,22 +838,33 @@ namespace Nz
Append(node.parameters[i].type);
Append(" ");
Append(node.parameters[i].name);
RegisterVariable(varIndex++, node.parameters[i].name);
}
Append(")\n");
EnterScope();
PushScope();
{
for (auto& statement : node.statements)
statement->Visit(*this);
}
PopScope();
LeaveScope();
}
void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node)
{
RegisterStruct(node.description);
bool isStd140 = false;
for (const auto& [attributeType, attributeParam] : node.attributes)
{
if (attributeType == ShaderAst::AttributeType::Layout && std::get<std::string>(attributeParam) == "std140")
{
isStd140 = true;
break;
}
}
assert(node.structIndex);
RegisterStruct(*node.structIndex, isStd140, node.description);
Append("struct ");
AppendLine(node.description.name);
@ -866,7 +890,8 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
{
RegisterVariable(node.varName, node.varType);
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);
Append(node.varType);
Append(" ");
@ -891,11 +916,6 @@ namespace Nz
AppendLine(";");
}
void GlslWriter::Visit(ShaderAst::IdentifierExpression& node)
{
Append(node.identifier);
}
void GlslWriter::Visit(ShaderAst::IntrinsicExpression& node)
{
switch (node.intrinsic)
@ -926,8 +946,6 @@ namespace Nz
void GlslWriter::Visit(ShaderAst::MultiStatement& node)
{
PushScope();
bool first = true;
for (const ShaderAst::StatementPtr& statement : node.statements)
{
@ -938,8 +956,6 @@ namespace Nz
first = false;
}
PopScope();
}
void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/)
@ -987,6 +1003,12 @@ namespace Nz
}
}
void GlslWriter::Visit(ShaderAst::VariableExpression& node)
{
const std::string& varName = m_currentState->variableNames[node.variableId];
Append(varName);
}
bool GlslWriter::HasExplicitBinding(ShaderAst::StatementPtr& shader)
{
/*for (const auto& uniform : shader.GetUniforms())

View File

@ -567,7 +567,6 @@ namespace Nz::ShaderAst
for (const auto& extVar : node.externalVars)
{
bool hasBinding = false;
bool hasLayout = false;
for (const auto& [attributeType, arg] : extVar.attributes)
{
switch (attributeType)
@ -588,21 +587,6 @@ namespace Nz::ShaderAst
break;
}
case AttributeType::Layout:
{
if (hasLayout)
throw AstError{ "attribute layout must be present once" };
if (!std::holds_alternative<std::string>(arg))
throw AstError{ "attribute layout requires a string parameter" };
if (std::get<std::string>(arg) != "std140")
throw AstError{ "unknown layout type" };
hasLayout = true;
break;
}
default:
throw AstError{ "unhandled attribute for external variable" };
}

View File

@ -238,7 +238,7 @@ namespace Nz
Int32(memberIndex),
m_constantCache.Register(*m_constantCache.BuildType(member.type)),
varId
});
});
}
memberIndex++;