Rework shader AST (WIP)
This commit is contained in:
@@ -167,8 +167,8 @@ namespace Nz
|
||||
auto& fragmentShader = settings.shaders[UnderlyingCast(ShaderStageType::Fragment)];
|
||||
auto& vertexShader = settings.shaders[UnderlyingCast(ShaderStageType::Vertex)];
|
||||
|
||||
fragmentShader = std::make_shared<UberShader>(UnserializeShader(r_fragmentShader, sizeof(r_fragmentShader)));
|
||||
vertexShader = std::make_shared<UberShader>(UnserializeShader(r_vertexShader, sizeof(r_vertexShader)));
|
||||
fragmentShader = std::make_shared<UberShader>(ShaderAst::UnserializeShader(r_fragmentShader, sizeof(r_fragmentShader)));
|
||||
vertexShader = std::make_shared<UberShader>(ShaderAst::UnserializeShader(r_vertexShader, sizeof(r_vertexShader)));
|
||||
|
||||
// Conditions
|
||||
|
||||
|
||||
@@ -5,17 +5,17 @@
|
||||
#include <Nazara/Graphics/UberShader.hpp>
|
||||
#include <Nazara/Graphics/Graphics.hpp>
|
||||
#include <Nazara/Renderer/RenderDevice.hpp>
|
||||
#include <Nazara/Shader/ShaderAst.hpp>
|
||||
#include <limits>
|
||||
#include <stdexcept>
|
||||
#include <Nazara/Graphics/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
UberShader::UberShader(ShaderAst shaderAst) :
|
||||
UberShader::UberShader(ShaderAst::StatementPtr shaderAst) :
|
||||
m_shaderAst(std::move(shaderAst))
|
||||
{
|
||||
std::size_t conditionCount = m_shaderAst.GetConditionCount();
|
||||
//std::size_t conditionCount = m_shaderAst.GetConditionCount();
|
||||
std::size_t conditionCount = 0;
|
||||
|
||||
if (conditionCount >= 64)
|
||||
throw std::runtime_error("Too many conditions");
|
||||
@@ -27,10 +27,10 @@ namespace Nz
|
||||
|
||||
UInt64 UberShader::GetConditionFlagByName(const std::string_view& condition) const
|
||||
{
|
||||
std::size_t conditionIndex = m_shaderAst.FindConditionByName(condition);
|
||||
/*std::size_t conditionIndex = m_shaderAst.FindConditionByName(condition);
|
||||
if (conditionIndex != ShaderAst::InvalidCondition)
|
||||
return SetBit<UInt64>(0, conditionIndex);
|
||||
else
|
||||
else*/
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <Nazara/Math/Algorithm.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <Nazara/Shader/ShaderAstCloner.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Shader/ShaderAstUtils.hpp>
|
||||
#include <Nazara/Shader/ShaderAstValidator.hpp>
|
||||
#include <optional>
|
||||
@@ -20,64 +21,67 @@ namespace Nz
|
||||
{
|
||||
static const char* flipYUniformName = "_NzFlipValue";
|
||||
|
||||
struct AstAdapter : ShaderAstCloner
|
||||
struct AstAdapter : ShaderAst::AstCloner
|
||||
{
|
||||
void Visit(ShaderNodes::AssignOp& node) override
|
||||
void Visit(ShaderAst::AssignExpression& node) override
|
||||
{
|
||||
if (!flipYPosition)
|
||||
return AstCloner::Visit(node);
|
||||
|
||||
if (node.left->GetType() != ShaderAst::NodeType::IdentifierExpression)
|
||||
return AstCloner::Visit(node);
|
||||
|
||||
/*
|
||||
FIXME:
|
||||
const auto& identifier = static_cast<const ShaderAst::Identifier&>(*node.left);
|
||||
if (identifier.var->GetType() != ShaderAst::VariableType::BuiltinVariable)
|
||||
return ShaderAstCloner::Visit(node);
|
||||
|
||||
if (node.left->GetType() != ShaderNodes::NodeType::Identifier)
|
||||
const auto& builtinVar = static_cast<const ShaderAst::BuiltinVariable&>(*identifier.var);
|
||||
if (builtinVar.entry != ShaderAst::BuiltinEntry::VertexPosition)
|
||||
return ShaderAstCloner::Visit(node);
|
||||
|
||||
const auto& identifier = static_cast<const ShaderNodes::Identifier&>(*node.left);
|
||||
if (identifier.var->GetType() != ShaderNodes::VariableType::BuiltinVariable)
|
||||
return ShaderAstCloner::Visit(node);
|
||||
|
||||
const auto& builtinVar = static_cast<const ShaderNodes::BuiltinVariable&>(*identifier.var);
|
||||
if (builtinVar.entry != ShaderNodes::BuiltinEntry::VertexPosition)
|
||||
return ShaderAstCloner::Visit(node);
|
||||
|
||||
auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderNodes::BasicType::Float1);
|
||||
auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderAst::BasicType::Float1);
|
||||
|
||||
auto oneConstant = ShaderBuilder::Constant(1.f);
|
||||
auto fixYValue = ShaderBuilder::Cast<ShaderNodes::BasicType::Float4>(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant);
|
||||
auto fixYValue = ShaderBuilder::Cast<ShaderAst::BasicType::Float4>(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant);
|
||||
auto mulFix = ShaderBuilder::Multiply(CloneExpression(node.right), fixYValue);
|
||||
|
||||
PushExpression(ShaderNodes::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));
|
||||
PushExpression(ShaderAst::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));*/
|
||||
}
|
||||
|
||||
bool flipYPosition = false;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
struct GlslWriter::State
|
||||
{
|
||||
const States* states = nullptr;
|
||||
ShaderAst::AstCache cache;
|
||||
std::stringstream stream;
|
||||
unsigned int indentLevel = 0;
|
||||
};
|
||||
|
||||
|
||||
GlslWriter::GlslWriter() :
|
||||
m_currentState(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
std::string GlslWriter::Generate(const ShaderAst& inputShader, const States& conditions)
|
||||
std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
|
||||
{
|
||||
const ShaderAst* selectedShader = &inputShader;
|
||||
/*const ShaderAst* selectedShader = &inputShader;
|
||||
std::optional<ShaderAst> modifiedShader;
|
||||
if (inputShader.GetStage() == ShaderStageType::Vertex && m_environment.flipYPosition)
|
||||
{
|
||||
modifiedShader.emplace(inputShader);
|
||||
|
||||
modifiedShader->AddUniform(flipYUniformName, ShaderNodes::BasicType::Float1);
|
||||
modifiedShader->AddUniform(flipYUniformName, ShaderAst::BasicType::Float1);
|
||||
|
||||
selectedShader = &modifiedShader.value();
|
||||
}
|
||||
|
||||
const ShaderAst& shader = *selectedShader;
|
||||
|
||||
std::string error;
|
||||
if (!ValidateShader(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
m_context.states = &conditions;
|
||||
m_context.shader = &shader;
|
||||
|
||||
}*/
|
||||
|
||||
State state;
|
||||
m_currentState = &state;
|
||||
CallOnExit onExit([this]()
|
||||
@@ -85,6 +89,10 @@ namespace Nz
|
||||
m_currentState = nullptr;
|
||||
});
|
||||
|
||||
std::string error;
|
||||
if (!ShaderAst::ValidateAst(shader, &error, &state.cache))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
unsigned int glslVersion;
|
||||
if (m_environment.glES)
|
||||
{
|
||||
@@ -165,52 +173,7 @@ namespace Nz
|
||||
AppendLine();
|
||||
}
|
||||
|
||||
// Structures
|
||||
/*if (shader.GetStructCount() > 0)
|
||||
{
|
||||
AppendCommentSection("Structures");
|
||||
for (const auto& s : shader.GetStructs())
|
||||
{
|
||||
Append("struct ");
|
||||
AppendLine(s.name);
|
||||
AppendLine("{");
|
||||
for (const auto& m : s.members)
|
||||
{
|
||||
Append("\t");
|
||||
Append(m.type);
|
||||
Append(" ");
|
||||
Append(m.name);
|
||||
AppendLine(";");
|
||||
}
|
||||
AppendLine("};");
|
||||
AppendLine();
|
||||
}
|
||||
}*/
|
||||
|
||||
// Global variables (uniforms, input and outputs)
|
||||
const char* inKeyword = (glslVersion >= 130) ? "in" : "varying";
|
||||
const char* outKeyword = (glslVersion >= 130) ? "out" : "varying";
|
||||
|
||||
DeclareVariables(shader, shader.GetUniforms(), "uniform", "Uniforms");
|
||||
DeclareVariables(shader, shader.GetInputs(), inKeyword, "Inputs");
|
||||
DeclareVariables(shader, shader.GetOutputs(), outKeyword, "Outputs");
|
||||
|
||||
std::size_t functionCount = shader.GetFunctionCount();
|
||||
if (functionCount > 1)
|
||||
{
|
||||
AppendCommentSection("Prototypes");
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
if (func.name != "main")
|
||||
{
|
||||
AppendFunctionPrototype(func);
|
||||
AppendLine(";");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
AppendFunction(func);
|
||||
shader->Visit(*this);
|
||||
|
||||
return state.stream.str();
|
||||
}
|
||||
@@ -225,7 +188,7 @@ namespace Nz
|
||||
return flipYUniformName;
|
||||
}
|
||||
|
||||
void GlslWriter::Append(ShaderExpressionType type)
|
||||
void GlslWriter::Append(ShaderAst::ShaderExpressionType type)
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
@@ -233,49 +196,57 @@ namespace Nz
|
||||
}, type);
|
||||
}
|
||||
|
||||
void GlslWriter::Append(ShaderNodes::BuiltinEntry builtin)
|
||||
void GlslWriter::Append(ShaderAst::BuiltinEntry builtin)
|
||||
{
|
||||
switch (builtin)
|
||||
{
|
||||
case ShaderNodes::BuiltinEntry::VertexPosition:
|
||||
case ShaderAst::BuiltinEntry::VertexPosition:
|
||||
Append("gl_Position");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::Append(ShaderNodes::BasicType type)
|
||||
void GlslWriter::Append(ShaderAst::BasicType type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case ShaderNodes::BasicType::Boolean: return Append("bool");
|
||||
case ShaderNodes::BasicType::Float1: return Append("float");
|
||||
case ShaderNodes::BasicType::Float2: return Append("vec2");
|
||||
case ShaderNodes::BasicType::Float3: return Append("vec3");
|
||||
case ShaderNodes::BasicType::Float4: return Append("vec4");
|
||||
case ShaderNodes::BasicType::Int1: return Append("int");
|
||||
case ShaderNodes::BasicType::Int2: return Append("ivec2");
|
||||
case ShaderNodes::BasicType::Int3: return Append("ivec3");
|
||||
case ShaderNodes::BasicType::Int4: return Append("ivec4");
|
||||
case ShaderNodes::BasicType::Mat4x4: return Append("mat4");
|
||||
case ShaderNodes::BasicType::Sampler2D: return Append("sampler2D");
|
||||
case ShaderNodes::BasicType::UInt1: return Append("uint");
|
||||
case ShaderNodes::BasicType::UInt2: return Append("uvec2");
|
||||
case ShaderNodes::BasicType::UInt3: return Append("uvec3");
|
||||
case ShaderNodes::BasicType::UInt4: return Append("uvec4");
|
||||
case ShaderNodes::BasicType::Void: return Append("void");
|
||||
case ShaderAst::BasicType::Boolean: return Append("bool");
|
||||
case ShaderAst::BasicType::Float1: return Append("float");
|
||||
case ShaderAst::BasicType::Float2: return Append("vec2");
|
||||
case ShaderAst::BasicType::Float3: return Append("vec3");
|
||||
case ShaderAst::BasicType::Float4: return Append("vec4");
|
||||
case ShaderAst::BasicType::Int1: return Append("int");
|
||||
case ShaderAst::BasicType::Int2: return Append("ivec2");
|
||||
case ShaderAst::BasicType::Int3: return Append("ivec3");
|
||||
case ShaderAst::BasicType::Int4: return Append("ivec4");
|
||||
case ShaderAst::BasicType::Mat4x4: return Append("mat4");
|
||||
case ShaderAst::BasicType::Sampler2D: return Append("sampler2D");
|
||||
case ShaderAst::BasicType::UInt1: return Append("uint");
|
||||
case ShaderAst::BasicType::UInt2: return Append("uvec2");
|
||||
case ShaderAst::BasicType::UInt3: return Append("uvec3");
|
||||
case ShaderAst::BasicType::UInt4: return Append("uvec4");
|
||||
case ShaderAst::BasicType::Void: return Append("void");
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::Append(ShaderNodes::MemoryLayout layout)
|
||||
void GlslWriter::Append(ShaderAst::MemoryLayout layout)
|
||||
{
|
||||
switch (layout)
|
||||
{
|
||||
case ShaderNodes::MemoryLayout::Std140:
|
||||
case ShaderAst::MemoryLayout::Std140:
|
||||
Append("std140");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void GlslWriter::Append(const T& param)
|
||||
{
|
||||
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
|
||||
|
||||
m_currentState->stream << param;
|
||||
}
|
||||
|
||||
void GlslWriter::AppendCommentSection(const std::string& section)
|
||||
{
|
||||
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
|
||||
@@ -285,67 +256,24 @@ namespace Nz
|
||||
AppendLine();
|
||||
}
|
||||
|
||||
void GlslWriter::AppendField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers)
|
||||
void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
|
||||
{
|
||||
const auto& structs = m_context.shader->GetStructs();
|
||||
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
|
||||
assert(it != structs.end());
|
||||
|
||||
const ShaderAst::Struct& s = *it;
|
||||
assert(*memberIndex < s.members.size());
|
||||
|
||||
const auto& member = s.members[*memberIndex];
|
||||
Append(".");
|
||||
Append(member.name);
|
||||
Append(memberIdentifier[0]);
|
||||
|
||||
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(scopeId, structName);
|
||||
assert(identifier);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
|
||||
|
||||
auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
|
||||
assert(memberIt != s.members.end());
|
||||
|
||||
const auto& member = *memberIt;
|
||||
|
||||
if (remainingMembers > 1)
|
||||
{
|
||||
assert(IsStructType(member.type));
|
||||
AppendField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::AppendFunction(const ShaderAst::Function& func)
|
||||
{
|
||||
NazaraAssert(!m_context.currentFunction, "A function is already being processed");
|
||||
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
|
||||
|
||||
AppendFunctionPrototype(func);
|
||||
|
||||
m_context.currentFunction = &func;
|
||||
CallOnExit onExit([this] ()
|
||||
{
|
||||
m_context.currentFunction = nullptr;
|
||||
});
|
||||
|
||||
EnterScope();
|
||||
{
|
||||
AstAdapter adapter;
|
||||
adapter.flipYPosition = m_environment.flipYPosition;
|
||||
|
||||
Visit(adapter.Clone(func.statement));
|
||||
}
|
||||
LeaveScope();
|
||||
}
|
||||
|
||||
void GlslWriter::AppendFunctionPrototype(const ShaderAst::Function& func)
|
||||
{
|
||||
Append(func.returnType);
|
||||
|
||||
Append(" ");
|
||||
Append(func.name);
|
||||
|
||||
Append("(");
|
||||
for (std::size_t i = 0; i < func.parameters.size(); ++i)
|
||||
{
|
||||
if (i != 0)
|
||||
Append(", ");
|
||||
|
||||
Append(func.parameters[i].type);
|
||||
Append(" ");
|
||||
Append(func.parameters[i].name);
|
||||
}
|
||||
Append(")\n");
|
||||
AppendField(scopeId, std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
|
||||
}
|
||||
|
||||
void GlslWriter::AppendLine(const std::string& txt)
|
||||
@@ -372,44 +300,46 @@ namespace Nz
|
||||
AppendLine("}");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired)
|
||||
void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired)
|
||||
{
|
||||
bool enclose = encloseIfRequired && (GetExpressionCategory(expr) != ShaderNodes::ExpressionCategory::LValue);
|
||||
bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue);
|
||||
|
||||
if (enclose)
|
||||
Append("(");
|
||||
|
||||
ShaderAstVisitor::Visit(expr);
|
||||
expr->Visit(*this);
|
||||
|
||||
if (enclose)
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::AccessMember& node)
|
||||
void GlslWriter::Visit(ShaderAst::AccessMemberExpression& node)
|
||||
{
|
||||
Visit(node.structExpr, true);
|
||||
|
||||
const ShaderExpressionType& exprType = node.structExpr->GetExpressionType();
|
||||
const ShaderAst::ShaderExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
|
||||
assert(IsStructType(exprType));
|
||||
|
||||
AppendField(std::get<std::string>(exprType), node.memberIndices.data(), node.memberIndices.size());
|
||||
std::size_t scopeId = m_currentState->cache.GetScopeId(&node);
|
||||
|
||||
AppendField(scopeId, std::get<std::string>(exprType), node.memberIdentifiers.data(), node.memberIdentifiers.size());
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::AssignOp& node)
|
||||
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
|
||||
{
|
||||
Visit(node.left);
|
||||
node.left->Visit(*this);
|
||||
|
||||
switch (node.op)
|
||||
{
|
||||
case ShaderNodes::AssignType::Simple:
|
||||
case ShaderAst::AssignType::Simple:
|
||||
Append(" = ");
|
||||
break;
|
||||
}
|
||||
|
||||
Visit(node.right);
|
||||
node.left->Visit(*this);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::Branch& node)
|
||||
void GlslWriter::Visit(ShaderAst::BranchStatement& node)
|
||||
{
|
||||
bool first = true;
|
||||
for (const auto& statement : node.condStatements)
|
||||
@@ -418,11 +348,11 @@ namespace Nz
|
||||
Append("else ");
|
||||
|
||||
Append("if (");
|
||||
Visit(statement.condition);
|
||||
statement.condition->Visit(*this);
|
||||
AppendLine(")");
|
||||
|
||||
EnterScope();
|
||||
Visit(statement.statement);
|
||||
statement.statement->Visit(*this);
|
||||
LeaveScope();
|
||||
|
||||
first = false;
|
||||
@@ -433,41 +363,36 @@ namespace Nz
|
||||
AppendLine("else");
|
||||
|
||||
EnterScope();
|
||||
Visit(node.elseStatement);
|
||||
node.elseStatement->Visit(*this);
|
||||
LeaveScope();
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::BinaryOp& node)
|
||||
void GlslWriter::Visit(ShaderAst::BinaryExpression& node)
|
||||
{
|
||||
Visit(node.left, true);
|
||||
|
||||
switch (node.op)
|
||||
{
|
||||
case ShaderNodes::BinaryType::Add: Append(" + "); break;
|
||||
case ShaderNodes::BinaryType::Subtract: Append(" - "); break;
|
||||
case ShaderNodes::BinaryType::Multiply: Append(" * "); break;
|
||||
case ShaderNodes::BinaryType::Divide: Append(" / "); break;
|
||||
case ShaderAst::BinaryType::Add: Append(" + "); break;
|
||||
case ShaderAst::BinaryType::Subtract: Append(" - "); break;
|
||||
case ShaderAst::BinaryType::Multiply: Append(" * "); break;
|
||||
case ShaderAst::BinaryType::Divide: Append(" / "); break;
|
||||
|
||||
case ShaderNodes::BinaryType::CompEq: Append(" == "); break;
|
||||
case ShaderNodes::BinaryType::CompGe: Append(" >= "); break;
|
||||
case ShaderNodes::BinaryType::CompGt: Append(" > "); break;
|
||||
case ShaderNodes::BinaryType::CompLe: Append(" <= "); break;
|
||||
case ShaderNodes::BinaryType::CompLt: Append(" < "); break;
|
||||
case ShaderNodes::BinaryType::CompNe: Append(" != "); break;
|
||||
case ShaderAst::BinaryType::CompEq: Append(" == "); break;
|
||||
case ShaderAst::BinaryType::CompGe: Append(" >= "); break;
|
||||
case ShaderAst::BinaryType::CompGt: Append(" > "); break;
|
||||
case ShaderAst::BinaryType::CompLe: Append(" <= "); break;
|
||||
case ShaderAst::BinaryType::CompLt: Append(" < "); break;
|
||||
case ShaderAst::BinaryType::CompNe: Append(" != "); break;
|
||||
}
|
||||
|
||||
Visit(node.right, true);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::BuiltinVariable& var)
|
||||
void GlslWriter::Visit(ShaderAst::CastExpression& node)
|
||||
{
|
||||
Append(var.entry);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::Cast& node)
|
||||
{
|
||||
Append(node.exprType);
|
||||
Append(node.targetType);
|
||||
Append("(");
|
||||
|
||||
bool first = true;
|
||||
@@ -479,34 +404,34 @@ namespace Nz
|
||||
if (!first)
|
||||
m_currentState->stream << ", ";
|
||||
|
||||
Visit(exprPtr);
|
||||
exprPtr->Visit(*this);
|
||||
first = false;
|
||||
}
|
||||
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
void GlslWriter::Visit(ShaderAst::ConditionalExpression& node)
|
||||
{
|
||||
std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
|
||||
/*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != ShaderAst::InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex))
|
||||
Visit(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);
|
||||
Visit(node.falsePath);*/
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
void GlslWriter::Visit(ShaderAst::ConditionalStatement& node)
|
||||
{
|
||||
std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
|
||||
/*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != ShaderAst::InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex))
|
||||
Visit(node.statement);
|
||||
Visit(node.statement);*/
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::Constant& node)
|
||||
void GlslWriter::Visit(ShaderAst::ConstantExpression& node)
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
@@ -530,54 +455,74 @@ namespace Nz
|
||||
}, node.value);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::DeclareVariable& node)
|
||||
void GlslWriter::Visit(ShaderAst::DeclareFunctionStatement& node)
|
||||
{
|
||||
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
|
||||
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
|
||||
|
||||
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
|
||||
|
||||
Append(localVar.type);
|
||||
Append(node.returnType);
|
||||
Append(" ");
|
||||
Append(localVar.name);
|
||||
if (node.expression)
|
||||
Append(node.name);
|
||||
Append("(");
|
||||
for (std::size_t i = 0; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (i != 0)
|
||||
Append(", ");
|
||||
Append(node.parameters[i].type);
|
||||
Append(" ");
|
||||
Append(node.parameters[i].name);
|
||||
}
|
||||
Append(")\n");
|
||||
|
||||
EnterScope();
|
||||
{
|
||||
AstAdapter adapter;
|
||||
adapter.flipYPosition = m_environment.flipYPosition;
|
||||
|
||||
for (auto& statement : node.statements)
|
||||
adapter.Clone(statement)->Visit(*this);
|
||||
}
|
||||
LeaveScope();
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
|
||||
{
|
||||
Append(node.varType);
|
||||
Append(" ");
|
||||
Append(node.varName);
|
||||
if (node.initialExpression)
|
||||
{
|
||||
Append(" = ");
|
||||
Visit(node.expression);
|
||||
node.initialExpression->Visit(*this);
|
||||
}
|
||||
|
||||
AppendLine(";");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::Discard& /*node*/)
|
||||
void GlslWriter::Visit(ShaderAst::DiscardStatement& /*node*/)
|
||||
{
|
||||
Append("discard;");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::ExpressionStatement& node)
|
||||
void GlslWriter::Visit(ShaderAst::ExpressionStatement& node)
|
||||
{
|
||||
Visit(node.expression);
|
||||
node.expression->Visit(*this);
|
||||
Append(";");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::Identifier& node)
|
||||
void GlslWriter::Visit(ShaderAst::IdentifierExpression& node)
|
||||
{
|
||||
Visit(node.var);
|
||||
Append(node.identifier);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::InputVariable& var)
|
||||
{
|
||||
Append(var.name);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::IntrinsicCall& node)
|
||||
void GlslWriter::Visit(ShaderAst::IntrinsicExpression& node)
|
||||
{
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
case ShaderNodes::IntrinsicType::CrossProduct:
|
||||
case ShaderAst::IntrinsicType::CrossProduct:
|
||||
Append("cross");
|
||||
break;
|
||||
|
||||
case ShaderNodes::IntrinsicType::DotProduct:
|
||||
case ShaderAst::IntrinsicType::DotProduct:
|
||||
Append("dot");
|
||||
break;
|
||||
}
|
||||
@@ -588,67 +533,43 @@ namespace Nz
|
||||
if (i != 0)
|
||||
Append(", ");
|
||||
|
||||
Visit(node.parameters[i]);
|
||||
node.parameters[i]->Visit(*this);
|
||||
}
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::LocalVariable& var)
|
||||
void GlslWriter::Visit(ShaderAst::MultiStatement& node)
|
||||
{
|
||||
Append(var.name);
|
||||
bool first = true;
|
||||
for (const ShaderAst::StatementPtr& statement : node.statements)
|
||||
{
|
||||
if (!first && statement->GetType() != ShaderAst::NodeType::NoOpStatement)
|
||||
AppendLine();
|
||||
|
||||
statement->Visit(*this);
|
||||
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::NoOp& /*node*/)
|
||||
void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/)
|
||||
{
|
||||
/* nothing to do */
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::ParameterVariable& var)
|
||||
{
|
||||
Append(var.name);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::ReturnStatement& node)
|
||||
void GlslWriter::Visit(ShaderAst::ReturnStatement& node)
|
||||
{
|
||||
if (node.returnExpr)
|
||||
{
|
||||
Append("return ");
|
||||
Visit(node.returnExpr);
|
||||
node.returnExpr->Visit(*this);
|
||||
Append(";");
|
||||
}
|
||||
else
|
||||
Append("return;");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::OutputVariable& var)
|
||||
{
|
||||
Append(var.name);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::Sample2D& node)
|
||||
{
|
||||
Append("texture(");
|
||||
Visit(node.sampler);
|
||||
Append(", ");
|
||||
Visit(node.coordinates);
|
||||
Append(")");
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::StatementBlock& node)
|
||||
{
|
||||
bool first = true;
|
||||
for (const ShaderNodes::StatementPtr& statement : node.statements)
|
||||
{
|
||||
if (!first && statement->GetType() != ShaderNodes::NodeType::NoOp)
|
||||
AppendLine();
|
||||
|
||||
Visit(statement);
|
||||
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::SwizzleOp& node)
|
||||
void GlslWriter::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
Visit(node.expression, true);
|
||||
Append(".");
|
||||
@@ -657,44 +578,39 @@ namespace Nz
|
||||
{
|
||||
switch (node.components[i])
|
||||
{
|
||||
case ShaderNodes::SwizzleComponent::First:
|
||||
case ShaderAst::SwizzleComponent::First:
|
||||
Append("x");
|
||||
break;
|
||||
|
||||
case ShaderNodes::SwizzleComponent::Second:
|
||||
case ShaderAst::SwizzleComponent::Second:
|
||||
Append("y");
|
||||
break;
|
||||
|
||||
case ShaderNodes::SwizzleComponent::Third:
|
||||
case ShaderAst::SwizzleComponent::Third:
|
||||
Append("z");
|
||||
break;
|
||||
|
||||
case ShaderNodes::SwizzleComponent::Fourth:
|
||||
case ShaderAst::SwizzleComponent::Fourth:
|
||||
Append("w");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(ShaderNodes::UniformVariable& var)
|
||||
bool GlslWriter::HasExplicitBinding(ShaderAst::StatementPtr& shader)
|
||||
{
|
||||
Append(var.name);
|
||||
}
|
||||
|
||||
bool GlslWriter::HasExplicitBinding(const ShaderAst& shader)
|
||||
{
|
||||
for (const auto& uniform : shader.GetUniforms())
|
||||
/*for (const auto& uniform : shader.GetUniforms())
|
||||
{
|
||||
if (uniform.bindingIndex.has_value())
|
||||
return true;
|
||||
}
|
||||
}*/
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GlslWriter::HasExplicitLocation(const ShaderAst& shader)
|
||||
bool GlslWriter::HasExplicitLocation(ShaderAst::StatementPtr& shader)
|
||||
{
|
||||
for (const auto& input : shader.GetInputs())
|
||||
/*for (const auto& input : shader.GetInputs())
|
||||
{
|
||||
if (input.locationIndex.has_value())
|
||||
return true;
|
||||
@@ -704,7 +620,7 @@ namespace Nz
|
||||
{
|
||||
if (output.locationIndex.has_value())
|
||||
return true;
|
||||
}
|
||||
}*/
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAst.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
void ShaderAst::AddCondition(std::string name)
|
||||
{
|
||||
auto& conditionEntry = m_conditions.emplace_back();
|
||||
conditionEntry.name = std::move(name);
|
||||
}
|
||||
|
||||
void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters, ShaderExpressionType returnType)
|
||||
{
|
||||
auto& functionEntry = m_functions.emplace_back();
|
||||
functionEntry.name = std::move(name);
|
||||
functionEntry.parameters = std::move(parameters);
|
||||
functionEntry.returnType = returnType;
|
||||
functionEntry.statement = std::move(statement);
|
||||
}
|
||||
|
||||
void ShaderAst::AddInput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex)
|
||||
{
|
||||
auto& inputEntry = m_inputs.emplace_back();
|
||||
inputEntry.name = std::move(name);
|
||||
inputEntry.locationIndex = std::move(locationIndex);
|
||||
inputEntry.type = std::move(type);
|
||||
}
|
||||
|
||||
void ShaderAst::AddOutput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex)
|
||||
{
|
||||
auto& outputEntry = m_outputs.emplace_back();
|
||||
outputEntry.name = std::move(name);
|
||||
outputEntry.locationIndex = std::move(locationIndex);
|
||||
outputEntry.type = std::move(type);
|
||||
}
|
||||
|
||||
void ShaderAst::AddStruct(std::string name, std::vector<StructMember> members)
|
||||
{
|
||||
auto& structEntry = m_structs.emplace_back();
|
||||
structEntry.name = std::move(name);
|
||||
structEntry.members = std::move(members);
|
||||
}
|
||||
|
||||
void ShaderAst::AddUniform(std::string name, ShaderExpressionType type, std::optional<std::size_t> bindingIndex, std::optional<ShaderNodes::MemoryLayout> memoryLayout)
|
||||
{
|
||||
auto& uniformEntry = m_uniforms.emplace_back();
|
||||
uniformEntry.bindingIndex = std::move(bindingIndex);
|
||||
uniformEntry.memoryLayout = std::move(memoryLayout);
|
||||
uniformEntry.name = std::move(name);
|
||||
uniformEntry.type = std::move(type);
|
||||
}
|
||||
}
|
||||
@@ -6,240 +6,257 @@
|
||||
#include <stdexcept>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
ShaderNodes::StatementPtr ShaderAstCloner::Clone(const ShaderNodes::StatementPtr& statement)
|
||||
ExpressionPtr AstCloner::Clone(ExpressionPtr& expr)
|
||||
{
|
||||
ShaderAstVisitor::Visit(statement);
|
||||
expr->Visit(*this);
|
||||
|
||||
if (!m_expressionStack.empty() || !m_variableStack.empty() || m_statementStack.size() != 1)
|
||||
throw std::runtime_error("An error occurred during clone");
|
||||
assert(m_statementStack.empty() && m_expressionStack.size() == 1);
|
||||
return PopExpression();
|
||||
}
|
||||
|
||||
StatementPtr AstCloner::Clone(StatementPtr& statement)
|
||||
{
|
||||
statement->Visit(*this);
|
||||
|
||||
assert(m_expressionStack.empty() && m_statementStack.size() == 1);
|
||||
return PopStatement();
|
||||
}
|
||||
|
||||
ShaderNodes::ExpressionPtr ShaderAstCloner::CloneExpression(const ShaderNodes::ExpressionPtr& expr)
|
||||
ExpressionPtr AstCloner::CloneExpression(ExpressionPtr& expr)
|
||||
{
|
||||
if (!expr)
|
||||
return nullptr;
|
||||
|
||||
ShaderAstVisitor::Visit(expr);
|
||||
expr->Visit(*this);
|
||||
return PopExpression();
|
||||
}
|
||||
|
||||
ShaderNodes::StatementPtr ShaderAstCloner::CloneStatement(const ShaderNodes::StatementPtr& statement)
|
||||
StatementPtr AstCloner::CloneStatement(StatementPtr& statement)
|
||||
{
|
||||
if (!statement)
|
||||
return nullptr;
|
||||
|
||||
ShaderAstVisitor::Visit(statement);
|
||||
statement->Visit(*this);
|
||||
return PopStatement();
|
||||
}
|
||||
|
||||
ShaderNodes::VariablePtr ShaderAstCloner::CloneVariable(const ShaderNodes::VariablePtr& variable)
|
||||
void AstCloner::Visit(AccessMemberExpression& node)
|
||||
{
|
||||
if (!variable)
|
||||
return nullptr;
|
||||
auto clone = std::make_unique<AccessMemberExpression>();
|
||||
clone->memberIdentifiers = node.memberIdentifiers;
|
||||
clone->structExpr = CloneExpression(node.structExpr);
|
||||
|
||||
ShaderVarVisitor::Visit(variable);
|
||||
return PopVariable();
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::AccessMember& node)
|
||||
void AstCloner::Visit(AssignExpression& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::AccessMember::Build(CloneExpression(node.structExpr), node.memberIndices, node.exprType));
|
||||
auto clone = std::make_unique<AssignExpression>();
|
||||
clone->op = node.op;
|
||||
clone->left = CloneExpression(node.left);
|
||||
clone->right = CloneExpression(node.right);
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::AssignOp& node)
|
||||
void AstCloner::Visit(BinaryExpression& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::AssignOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right)));
|
||||
auto clone = std::make_unique<BinaryExpression>();
|
||||
clone->op = node.op;
|
||||
clone->left = CloneExpression(node.left);
|
||||
clone->right = CloneExpression(node.right);
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::BinaryOp& node)
|
||||
void AstCloner::Visit(CastExpression& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::BinaryOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right)));
|
||||
}
|
||||
auto clone = std::make_unique<CastExpression>();
|
||||
clone->targetType = node.targetType;
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::Branch& node)
|
||||
{
|
||||
std::vector<ShaderNodes::Branch::ConditionalStatement> condStatements;
|
||||
condStatements.reserve(node.condStatements.size());
|
||||
|
||||
for (auto& cond : node.condStatements)
|
||||
{
|
||||
auto& condStatement = condStatements.emplace_back();
|
||||
condStatement.condition = CloneExpression(cond.condition);
|
||||
condStatement.statement = CloneStatement(cond.statement);
|
||||
}
|
||||
|
||||
PushStatement(ShaderNodes::Branch::Build(std::move(condStatements), CloneStatement(node.elseStatement)));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::Cast& node)
|
||||
{
|
||||
std::size_t expressionCount = 0;
|
||||
std::array<ShaderNodes::ExpressionPtr, 4> expressions;
|
||||
for (auto& expr : node.expressions)
|
||||
{
|
||||
if (!expr)
|
||||
break;
|
||||
|
||||
expressions[expressionCount] = CloneExpression(expr);
|
||||
expressionCount++;
|
||||
clone->expressions[expressionCount++] = CloneExpression(expr);
|
||||
}
|
||||
|
||||
PushExpression(ShaderNodes::Cast::Build(node.exprType, expressions.data(), expressionCount));
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
void AstCloner::Visit(ConditionalExpression& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::ConditionalExpression::Build(node.conditionName, CloneExpression(node.truePath), CloneExpression(node.falsePath)));
|
||||
auto clone = std::make_unique<ConditionalExpression>();
|
||||
clone->conditionName = node.conditionName;
|
||||
clone->falsePath = CloneExpression(node.falsePath);
|
||||
clone->truePath = CloneExpression(node.truePath);
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
void AstCloner::Visit(ConstantExpression& node)
|
||||
{
|
||||
PushStatement(ShaderNodes::ConditionalStatement::Build(node.conditionName, CloneStatement(node.statement)));
|
||||
auto clone = std::make_unique<ConstantExpression>();
|
||||
clone->value = node.value;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::Constant& node)
|
||||
void AstCloner::Visit(IdentifierExpression& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::Constant::Build(node.value));
|
||||
auto clone = std::make_unique<IdentifierExpression>();
|
||||
clone->identifier = node.identifier;
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::DeclareVariable& node)
|
||||
void AstCloner::Visit(IntrinsicExpression& node)
|
||||
{
|
||||
PushStatement(ShaderNodes::DeclareVariable::Build(CloneVariable(node.variable), CloneExpression(node.expression)));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::Discard& /*node*/)
|
||||
{
|
||||
PushStatement(ShaderNodes::Discard::Build());
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::ExpressionStatement& node)
|
||||
{
|
||||
PushStatement(ShaderNodes::ExpressionStatement::Build(CloneExpression(node.expression)));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::Identifier& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::Identifier::Build(CloneVariable(node.var)));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::IntrinsicCall& node)
|
||||
{
|
||||
std::vector<ShaderNodes::ExpressionPtr> parameters;
|
||||
parameters.reserve(node.parameters.size());
|
||||
auto clone = std::make_unique<IntrinsicExpression>();
|
||||
clone->intrinsic = node.intrinsic;
|
||||
|
||||
clone->parameters.reserve(node.parameters.size());
|
||||
for (auto& parameter : node.parameters)
|
||||
parameters.push_back(CloneExpression(parameter));
|
||||
clone->parameters.push_back(CloneExpression(parameter));
|
||||
|
||||
PushExpression(ShaderNodes::IntrinsicCall::Build(node.intrinsic, std::move(parameters)));
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::NoOp& /*node*/)
|
||||
void AstCloner::Visit(SwizzleExpression& node)
|
||||
{
|
||||
PushStatement(ShaderNodes::NoOp::Build());
|
||||
auto clone = std::make_unique<SwizzleExpression>();
|
||||
clone->componentCount = node.componentCount;
|
||||
clone->components = node.components;
|
||||
clone->expression = CloneExpression(node.expression);
|
||||
|
||||
PushExpression(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::ReturnStatement& node)
|
||||
void AstCloner::Visit(BranchStatement& node)
|
||||
{
|
||||
PushStatement(ShaderNodes::ReturnStatement::Build(CloneExpression(node.returnExpr)));
|
||||
auto clone = std::make_unique<BranchStatement>();
|
||||
clone->condStatements.reserve(node.condStatements.size());
|
||||
|
||||
for (auto& cond : node.condStatements)
|
||||
{
|
||||
auto& condStatement = clone->condStatements.emplace_back();
|
||||
condStatement.condition = CloneExpression(cond.condition);
|
||||
condStatement.statement = CloneStatement(cond.statement);
|
||||
}
|
||||
|
||||
clone->elseStatement = CloneStatement(node.elseStatement);
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::Sample2D& node)
|
||||
void AstCloner::Visit(ConditionalStatement& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::Sample2D::Build(CloneExpression(node.sampler), CloneExpression(node.coordinates)));
|
||||
auto clone = std::make_unique<ConditionalStatement>();
|
||||
clone->conditionName = node.conditionName;
|
||||
clone->statement = CloneStatement(node.statement);
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::StatementBlock& node)
|
||||
void AstCloner::Visit(DeclareFunctionStatement& node)
|
||||
{
|
||||
std::vector<ShaderNodes::StatementPtr> statements;
|
||||
statements.reserve(node.statements.size());
|
||||
auto clone = std::make_unique<DeclareFunctionStatement>();
|
||||
clone->name = node.name;
|
||||
clone->parameters = node.parameters;
|
||||
clone->returnType = node.returnType;
|
||||
|
||||
clone->statements.reserve(node.statements.size());
|
||||
for (auto& statement : node.statements)
|
||||
statements.push_back(CloneStatement(statement));
|
||||
clone->statements.push_back(CloneStatement(statement));
|
||||
|
||||
PushStatement(ShaderNodes::StatementBlock::Build(std::move(statements)));
|
||||
PushStatement(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::SwizzleOp& node)
|
||||
void AstCloner::Visit(DeclareStructStatement& node)
|
||||
{
|
||||
PushExpression(ShaderNodes::SwizzleOp::Build(CloneExpression(node.expression), node.components.data(), node.componentCount));
|
||||
auto clone = std::make_unique<DeclareStructStatement>();
|
||||
clone->description = node.description;
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::BuiltinVariable& var)
|
||||
void AstCloner::Visit(DeclareVariableStatement& node)
|
||||
{
|
||||
PushVariable(ShaderNodes::BuiltinVariable::Build(var.entry, var.type));
|
||||
auto clone = std::make_unique<DeclareVariableStatement>();
|
||||
clone->varName = node.varName;
|
||||
clone->varType = node.varType;
|
||||
clone->initialExpression = CloneExpression(node.initialExpression);
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::InputVariable& var)
|
||||
void AstCloner::Visit(DiscardStatement& /*node*/)
|
||||
{
|
||||
PushVariable(ShaderNodes::InputVariable::Build(var.name, var.type));
|
||||
PushStatement(std::make_unique<DiscardStatement>());
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::LocalVariable& var)
|
||||
void AstCloner::Visit(ExpressionStatement& node)
|
||||
{
|
||||
PushVariable(ShaderNodes::LocalVariable::Build(var.name, var.type));
|
||||
auto clone = std::make_unique<ExpressionStatement>();
|
||||
clone->expression = CloneExpression(node.expression);
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::OutputVariable& var)
|
||||
void AstCloner::Visit(MultiStatement& node)
|
||||
{
|
||||
PushVariable(ShaderNodes::OutputVariable::Build(var.name, var.type));
|
||||
auto clone = std::make_unique<MultiStatement>();
|
||||
clone->statements.reserve(node.statements.size());
|
||||
for (auto& statement : node.statements)
|
||||
clone->statements.push_back(CloneStatement(statement));
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::ParameterVariable& var)
|
||||
void AstCloner::Visit(NoOpStatement& /*node*/)
|
||||
{
|
||||
PushVariable(ShaderNodes::ParameterVariable::Build(var.name, var.type));
|
||||
PushStatement(std::make_unique<NoOpStatement>());
|
||||
}
|
||||
|
||||
void ShaderAstCloner::Visit(ShaderNodes::UniformVariable& var)
|
||||
void AstCloner::Visit(ReturnStatement& node)
|
||||
{
|
||||
PushVariable(ShaderNodes::UniformVariable::Build(var.name, var.type));
|
||||
auto clone = std::make_unique<ReturnStatement>();
|
||||
clone->returnExpr = CloneExpression(node.returnExpr);
|
||||
|
||||
PushStatement(std::move(clone));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::PushExpression(ShaderNodes::ExpressionPtr expression)
|
||||
void AstCloner::PushExpression(ExpressionPtr expression)
|
||||
{
|
||||
m_expressionStack.emplace_back(std::move(expression));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::PushStatement(ShaderNodes::StatementPtr statement)
|
||||
void AstCloner::PushStatement(StatementPtr statement)
|
||||
{
|
||||
m_statementStack.emplace_back(std::move(statement));
|
||||
}
|
||||
|
||||
void ShaderAstCloner::PushVariable(ShaderNodes::VariablePtr variable)
|
||||
{
|
||||
m_variableStack.emplace_back(std::move(variable));
|
||||
}
|
||||
|
||||
ShaderNodes::ExpressionPtr ShaderAstCloner::PopExpression()
|
||||
ExpressionPtr AstCloner::PopExpression()
|
||||
{
|
||||
assert(!m_expressionStack.empty());
|
||||
|
||||
ShaderNodes::ExpressionPtr expr = std::move(m_expressionStack.back());
|
||||
ExpressionPtr expr = std::move(m_expressionStack.back());
|
||||
m_expressionStack.pop_back();
|
||||
|
||||
return expr;
|
||||
}
|
||||
|
||||
ShaderNodes::StatementPtr ShaderAstCloner::PopStatement()
|
||||
StatementPtr AstCloner::PopStatement()
|
||||
{
|
||||
assert(!m_statementStack.empty());
|
||||
|
||||
ShaderNodes::StatementPtr expr = std::move(m_statementStack.back());
|
||||
StatementPtr expr = std::move(m_statementStack.back());
|
||||
m_statementStack.pop_back();
|
||||
|
||||
return expr;
|
||||
}
|
||||
|
||||
ShaderNodes::VariablePtr ShaderAstCloner::PopVariable()
|
||||
{
|
||||
assert(!m_variableStack.empty());
|
||||
|
||||
ShaderNodes::VariablePtr var = std::move(m_variableStack.back());
|
||||
m_variableStack.pop_back();
|
||||
|
||||
return var;
|
||||
}
|
||||
}
|
||||
|
||||
198
src/Nazara/Shader/ShaderAstExpressionType.cpp
Normal file
198
src/Nazara/Shader/ShaderAstExpressionType.cpp
Normal file
@@ -0,0 +1,198 @@
|
||||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Shader/ShaderAstCache.hpp>
|
||||
#include <optional>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
ShaderExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache = nullptr)
|
||||
{
|
||||
m_cache = cache;
|
||||
ShaderExpressionType type = GetExpressionTypeInternal(expression);
|
||||
m_cache = nullptr;
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
ShaderExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression)
|
||||
{
|
||||
m_lastExpressionType.reset();
|
||||
|
||||
Visit(expression);
|
||||
|
||||
assert(m_lastExpressionType.has_value());
|
||||
return std::move(*m_lastExpressionType);
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(Expression& expression)
|
||||
{
|
||||
if (m_cache)
|
||||
{
|
||||
auto it = m_cache->nodeExpressionType.find(&expression);
|
||||
if (it != m_cache->nodeExpressionType.end())
|
||||
{
|
||||
m_lastExpressionType = it->second;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
expression.Visit(*this);
|
||||
|
||||
if (m_cache)
|
||||
{
|
||||
assert(m_lastExpressionType.has_value());
|
||||
m_cache->nodeExpressionType.emplace(&expression, *m_lastExpressionType);
|
||||
}
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(AccessMemberExpression& node)
|
||||
{
|
||||
throw std::runtime_error("unhandled accessmember expression");
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(AssignExpression& node)
|
||||
{
|
||||
Visit(*node.left);
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(BinaryExpression& node)
|
||||
{
|
||||
switch (node.op)
|
||||
{
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Subtract:
|
||||
return Visit(*node.left);
|
||||
|
||||
case BinaryType::Divide:
|
||||
case BinaryType::Multiply:
|
||||
{
|
||||
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.left);
|
||||
assert(IsBasicType(leftExprType));
|
||||
|
||||
ShaderExpressionType rightExprType = GetExpressionTypeInternal(*node.right);
|
||||
assert(IsBasicType(rightExprType));
|
||||
|
||||
switch (std::get<BasicType>(leftExprType))
|
||||
{
|
||||
case BasicType::Boolean:
|
||||
case BasicType::Float2:
|
||||
case BasicType::Float3:
|
||||
case BasicType::Float4:
|
||||
case BasicType::Int2:
|
||||
case BasicType::Int3:
|
||||
case BasicType::Int4:
|
||||
case BasicType::UInt2:
|
||||
case BasicType::UInt3:
|
||||
case BasicType::UInt4:
|
||||
m_lastExpressionType = std::move(leftExprType);
|
||||
break;
|
||||
|
||||
case BasicType::Float1:
|
||||
case BasicType::Int1:
|
||||
case BasicType::Mat4x4:
|
||||
case BasicType::UInt1:
|
||||
m_lastExpressionType = std::move(rightExprType);
|
||||
break;
|
||||
|
||||
case BasicType::Sampler2D:
|
||||
case BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
case BinaryType::CompNe:
|
||||
m_lastExpressionType = BasicType::Boolean;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(CastExpression& node)
|
||||
{
|
||||
m_lastExpressionType = node.targetType;
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(ConditionalExpression& node)
|
||||
{
|
||||
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.truePath);
|
||||
assert(leftExprType == GetExpressionTypeInternal(*node.falsePath));
|
||||
|
||||
m_lastExpressionType = std::move(leftExprType);
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(ConstantExpression& node)
|
||||
{
|
||||
m_lastExpressionType = std::visit([&](auto&& arg)
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, bool>)
|
||||
return BasicType::Boolean;
|
||||
else if constexpr (std::is_same_v<T, float>)
|
||||
return BasicType::Float1;
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
return BasicType::Int1;
|
||||
else if constexpr (std::is_same_v<T, UInt32>)
|
||||
return BasicType::Int1;
|
||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||
return BasicType::Float2;
|
||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||
return BasicType::Float3;
|
||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||
return BasicType::Float4;
|
||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||
return BasicType::Int2;
|
||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||
return BasicType::Int3;
|
||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||
return BasicType::Int4;
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, node.value);
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(IdentifierExpression& node)
|
||||
{
|
||||
auto scopeIt = m_cache->scopeIdByNode.find(&node);
|
||||
if (scopeIt == m_cache->scopeIdByNode.end())
|
||||
throw std::runtime_error("internal error");
|
||||
|
||||
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, node.identifier);
|
||||
if (!identifier || !std::holds_alternative<AstCache::Variable>(identifier->value))
|
||||
throw std::runtime_error("internal error");
|
||||
|
||||
m_lastExpressionType = std::get<AstCache::Variable>(identifier->value).type;
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(IntrinsicExpression& node)
|
||||
{
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
case IntrinsicType::CrossProduct:
|
||||
Visit(*node.parameters.front());
|
||||
break;
|
||||
|
||||
case IntrinsicType::DotProduct:
|
||||
m_lastExpressionType = BasicType::Float1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ExpressionTypeVisitor::Visit(SwizzleExpression& node)
|
||||
{
|
||||
const ShaderExpressionType& exprType = GetExpressionTypeInternal(*node.expression);
|
||||
assert(IsBasicType(exprType));
|
||||
|
||||
m_lastExpressionType = static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + node.componentCount - 1);
|
||||
}
|
||||
}
|
||||
@@ -2,15 +2,10 @@
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
ShaderAstVisitor::~ShaderAstVisitor() = default;
|
||||
|
||||
void ShaderAstVisitor::Visit(const ShaderNodes::NodePtr& node)
|
||||
{
|
||||
node->Visit(*this);
|
||||
}
|
||||
AstExpressionVisitor::~AstExpressionVisitor() = default;
|
||||
}
|
||||
15
src/Nazara/Shader/ShaderAstExpressionVisitorExcept.cpp
Normal file
15
src/Nazara/Shader/ShaderAstExpressionVisitorExcept.cpp
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
#define NAZARA_SHADERAST_EXPRESSION(Node) void ExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
|
||||
{ \
|
||||
throw std::runtime_error("unexpected " #Node " node"); \
|
||||
}
|
||||
#include <Nazara/Shader/ShaderAstNodes.hpp>
|
||||
}
|
||||
@@ -3,16 +3,22 @@
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstOptimizer.hpp>
|
||||
#include <Nazara/Shader/ShaderAst.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
namespace
|
||||
{
|
||||
template<typename T, typename U>
|
||||
std::unique_ptr<T> static_unique_pointer_cast(std::unique_ptr<U>&& ptr)
|
||||
{
|
||||
return std::unique_ptr<T>(static_cast<T*>(ptr.release()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct is_complete_helper
|
||||
{
|
||||
@@ -29,14 +35,14 @@ namespace Nz
|
||||
inline constexpr bool is_complete_v = is_complete<T>::value;
|
||||
|
||||
|
||||
template<ShaderNodes::BinaryType Type, typename T1, typename T2>
|
||||
template<BinaryType Type, typename T1, typename T2>
|
||||
struct PropagateConstantType;
|
||||
|
||||
// CompEq
|
||||
template<typename T1, typename T2>
|
||||
struct CompEqBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs == rhs);
|
||||
}
|
||||
@@ -46,7 +52,7 @@ namespace Nz
|
||||
struct CompEq;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::CompEq, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::CompEq, T1, T2>
|
||||
{
|
||||
using Op = typename CompEq<T1, T2>;
|
||||
};
|
||||
@@ -55,7 +61,7 @@ namespace Nz
|
||||
template<typename T1, typename T2>
|
||||
struct CompGeBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs >= rhs);
|
||||
}
|
||||
@@ -65,7 +71,7 @@ namespace Nz
|
||||
struct CompGe;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::CompGe, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::CompGe, T1, T2>
|
||||
{
|
||||
using Op = typename CompGe<T1, T2>;
|
||||
};
|
||||
@@ -74,7 +80,7 @@ namespace Nz
|
||||
template<typename T1, typename T2>
|
||||
struct CompGtBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs > rhs);
|
||||
}
|
||||
@@ -84,7 +90,7 @@ namespace Nz
|
||||
struct CompGt;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::CompGt, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::CompGt, T1, T2>
|
||||
{
|
||||
using Op = typename CompGt<T1, T2>;
|
||||
};
|
||||
@@ -93,7 +99,7 @@ namespace Nz
|
||||
template<typename T1, typename T2>
|
||||
struct CompLeBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs <= rhs);
|
||||
}
|
||||
@@ -103,7 +109,7 @@ namespace Nz
|
||||
struct CompLe;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::CompLe, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::CompLe, T1, T2>
|
||||
{
|
||||
using Op = typename CompLe<T1, T2>;
|
||||
};
|
||||
@@ -112,7 +118,7 @@ namespace Nz
|
||||
template<typename T1, typename T2>
|
||||
struct CompLtBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs < rhs);
|
||||
}
|
||||
@@ -122,7 +128,7 @@ namespace Nz
|
||||
struct CompLt;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::CompLt, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::CompLt, T1, T2>
|
||||
{
|
||||
using Op = typename CompLe<T1, T2>;
|
||||
};
|
||||
@@ -131,7 +137,7 @@ namespace Nz
|
||||
template<typename T1, typename T2>
|
||||
struct CompNeBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs != rhs);
|
||||
}
|
||||
@@ -141,7 +147,7 @@ namespace Nz
|
||||
struct CompNe;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::CompNe, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::CompNe, T1, T2>
|
||||
{
|
||||
using Op = typename CompNe<T1, T2>;
|
||||
};
|
||||
@@ -150,7 +156,7 @@ namespace Nz
|
||||
template<typename T1, typename T2>
|
||||
struct AdditionBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs + rhs);
|
||||
}
|
||||
@@ -160,7 +166,7 @@ namespace Nz
|
||||
struct Addition;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::Add, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::Add, T1, T2>
|
||||
{
|
||||
using Op = typename Addition<T1, T2>;
|
||||
};
|
||||
@@ -169,7 +175,7 @@ namespace Nz
|
||||
template<typename T1, typename T2>
|
||||
struct DivisionBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs / rhs);
|
||||
}
|
||||
@@ -179,7 +185,7 @@ namespace Nz
|
||||
struct Division;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::Divide, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::Divide, T1, T2>
|
||||
{
|
||||
using Op = typename Division<T1, T2>;
|
||||
};
|
||||
@@ -188,7 +194,7 @@ namespace Nz
|
||||
template<typename T1, typename T2>
|
||||
struct MultiplicationBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs * rhs);
|
||||
}
|
||||
@@ -198,7 +204,7 @@ namespace Nz
|
||||
struct Multiplication;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::Multiply, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::Multiply, T1, T2>
|
||||
{
|
||||
using Op = typename Multiplication<T1, T2>;
|
||||
};
|
||||
@@ -207,7 +213,7 @@ namespace Nz
|
||||
template<typename T1, typename T2>
|
||||
struct SubtractionBase
|
||||
{
|
||||
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
|
||||
{
|
||||
return ShaderBuilder::Constant(lhs - rhs);
|
||||
}
|
||||
@@ -217,7 +223,7 @@ namespace Nz
|
||||
struct Subtraction;
|
||||
|
||||
template<typename T1, typename T2>
|
||||
struct PropagateConstantType<ShaderNodes::BinaryType::Subtract, T1, T2>
|
||||
struct PropagateConstantType<BinaryType::Subtract, T1, T2>
|
||||
{
|
||||
using Op = typename Subtraction<T1, T2>;
|
||||
};
|
||||
@@ -375,92 +381,89 @@ namespace Nz
|
||||
#undef EnableOptimisation
|
||||
}
|
||||
|
||||
ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement)
|
||||
StatementPtr AstOptimizer::Optimise(StatementPtr& statement)
|
||||
{
|
||||
m_shaderAst = nullptr;
|
||||
|
||||
return CloneStatement(statement);
|
||||
}
|
||||
|
||||
ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement, const ShaderAst& shader, UInt64 enabledConditions)
|
||||
StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions)
|
||||
{
|
||||
m_shaderAst = &shader;
|
||||
m_enabledConditions = enabledConditions;
|
||||
|
||||
return CloneStatement(statement);
|
||||
}
|
||||
|
||||
void ShaderAstOptimizer::Visit(ShaderNodes::BinaryOp& node)
|
||||
void AstOptimizer::Visit(BinaryExpression& node)
|
||||
{
|
||||
auto lhs = CloneExpression(node.left);
|
||||
auto rhs = CloneExpression(node.right);
|
||||
|
||||
if (lhs->GetType() == ShaderNodes::NodeType::Constant && rhs->GetType() == ShaderNodes::NodeType::Constant)
|
||||
if (lhs->GetType() == NodeType::ConstantExpression && rhs->GetType() == NodeType::ConstantExpression)
|
||||
{
|
||||
auto lhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(lhs);
|
||||
auto rhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(rhs);
|
||||
auto lhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(lhs));
|
||||
auto rhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(rhs));
|
||||
|
||||
switch (node.op)
|
||||
{
|
||||
case ShaderNodes::BinaryType::Add:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::Add>(lhsConstant, rhsConstant);
|
||||
case BinaryType::Add:
|
||||
return PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case ShaderNodes::BinaryType::Subtract:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::Subtract>(lhsConstant, rhsConstant);
|
||||
case BinaryType::Subtract:
|
||||
return PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case ShaderNodes::BinaryType::Multiply:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::Multiply>(lhsConstant, rhsConstant);
|
||||
case BinaryType::Multiply:
|
||||
return PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case ShaderNodes::BinaryType::Divide:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::Divide>(lhsConstant, rhsConstant);
|
||||
case BinaryType::Divide:
|
||||
return PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case ShaderNodes::BinaryType::CompEq:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::CompEq>(lhsConstant, rhsConstant);
|
||||
case BinaryType::CompEq:
|
||||
return PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case ShaderNodes::BinaryType::CompGe:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::CompGe>(lhsConstant, rhsConstant);
|
||||
case BinaryType::CompGe:
|
||||
return PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case ShaderNodes::BinaryType::CompGt:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::CompGt>(lhsConstant, rhsConstant);
|
||||
case BinaryType::CompGt:
|
||||
return PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case ShaderNodes::BinaryType::CompLe:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::CompLe>(lhsConstant, rhsConstant);
|
||||
case BinaryType::CompLe:
|
||||
return PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case ShaderNodes::BinaryType::CompLt:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::CompLt>(lhsConstant, rhsConstant);
|
||||
case BinaryType::CompLt:
|
||||
return PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
|
||||
case ShaderNodes::BinaryType::CompNe:
|
||||
return PropagateConstant<ShaderNodes::BinaryType::CompNe>(lhsConstant, rhsConstant);
|
||||
case BinaryType::CompNe:
|
||||
return PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
|
||||
}
|
||||
}
|
||||
|
||||
ShaderAstCloner::Visit(node);
|
||||
AstCloner::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstOptimizer::Visit(ShaderNodes::Branch& node)
|
||||
void AstOptimizer::Visit(BranchStatement& node)
|
||||
{
|
||||
std::vector<ShaderNodes::Branch::ConditionalStatement> statements;
|
||||
ShaderNodes::StatementPtr elseStatement;
|
||||
std::vector<BranchStatement::ConditionalStatement> statements;
|
||||
StatementPtr elseStatement;
|
||||
|
||||
for (auto& condStatement : node.condStatements)
|
||||
{
|
||||
auto cond = CloneExpression(condStatement.condition);
|
||||
|
||||
if (cond->GetType() == ShaderNodes::NodeType::Constant)
|
||||
if (cond->GetType() == NodeType::ConstantExpression)
|
||||
{
|
||||
auto constant = std::static_pointer_cast<ShaderNodes::Constant>(cond);
|
||||
auto& constant = static_cast<ConstantExpression&>(*cond);
|
||||
|
||||
assert(IsBasicType(cond->GetExpressionType()));
|
||||
assert(std::get<ShaderNodes::BasicType>(cond->GetExpressionType()) == ShaderNodes::BasicType::Boolean);
|
||||
assert(IsBasicType(GetExpressionType(constant)));
|
||||
assert(std::get<BasicType>(GetExpressionType(constant)) == BasicType::Boolean);
|
||||
|
||||
bool cValue = std::get<bool>(constant->value);
|
||||
bool cValue = std::get<bool>(constant.value);
|
||||
if (!cValue)
|
||||
continue;
|
||||
|
||||
if (statements.empty())
|
||||
{
|
||||
// First condition is true, dismiss the branch
|
||||
Visit(condStatement.statement);
|
||||
condStatement.statement->Visit(*this);
|
||||
return;
|
||||
}
|
||||
else
|
||||
@@ -482,47 +485,54 @@ namespace Nz
|
||||
{
|
||||
// All conditions have been removed, replace by else statement or no-op
|
||||
if (node.elseStatement)
|
||||
return Visit(node.elseStatement);
|
||||
{
|
||||
node.elseStatement->Visit(*this);
|
||||
return;
|
||||
}
|
||||
else
|
||||
return PushStatement(ShaderNodes::NoOp::Build());
|
||||
return PushStatement(ShaderBuilder::NoOp());
|
||||
}
|
||||
|
||||
if (!elseStatement)
|
||||
elseStatement = CloneStatement(node.elseStatement);
|
||||
|
||||
PushStatement(ShaderNodes::Branch::Build(std::move(statements), std::move(elseStatement)));
|
||||
PushStatement(ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)));
|
||||
}
|
||||
|
||||
void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
void AstOptimizer::Visit(ConditionalExpression& node)
|
||||
{
|
||||
if (!m_shaderAst)
|
||||
return AstCloner::Visit(node);
|
||||
|
||||
/*if (!m_shaderAst)
|
||||
return ShaderAstCloner::Visit(node);
|
||||
|
||||
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != ShaderAst::InvalidCondition);
|
||||
assert(conditionIndex != InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
|
||||
Visit(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);
|
||||
Visit(node.falsePath);*/
|
||||
}
|
||||
|
||||
void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
void AstOptimizer::Visit(ConditionalStatement& node)
|
||||
{
|
||||
if (!m_shaderAst)
|
||||
return AstCloner::Visit(node);
|
||||
|
||||
/*if (!m_shaderAst)
|
||||
return ShaderAstCloner::Visit(node);
|
||||
|
||||
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != ShaderAst::InvalidCondition);
|
||||
assert(conditionIndex != InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
|
||||
Visit(node.statement);
|
||||
Visit(node.statement);*/
|
||||
}
|
||||
|
||||
template<ShaderNodes::BinaryType Type>
|
||||
void ShaderAstOptimizer::PropagateConstant(const std::shared_ptr<ShaderNodes::Constant>& lhs, const std::shared_ptr<ShaderNodes::Constant>& rhs)
|
||||
template<BinaryType Type>
|
||||
void AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
|
||||
{
|
||||
ShaderNodes::ExpressionPtr optimized;
|
||||
ExpressionPtr optimized;
|
||||
std::visit([&](auto&& arg1)
|
||||
{
|
||||
using T1 = std::decay_t<decltype(arg1)>;
|
||||
@@ -543,8 +553,8 @@ namespace Nz
|
||||
}, lhs->value);
|
||||
|
||||
if (optimized)
|
||||
PushExpression(optimized);
|
||||
PushExpression(std::move(optimized));
|
||||
else
|
||||
PushExpression(ShaderNodes::BinaryOp::Build(Type, lhs, rhs));
|
||||
PushExpression(ShaderBuilder::Binary(Type, std::move(lhs), std::move(rhs)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,116 +5,121 @@
|
||||
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AccessMember& node)
|
||||
void AstRecursiveVisitor::Visit(AccessMemberExpression& node)
|
||||
{
|
||||
Visit(node.structExpr);
|
||||
node.structExpr->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AssignOp& node)
|
||||
void AstRecursiveVisitor::Visit(AssignExpression& node)
|
||||
{
|
||||
Visit(node.left);
|
||||
Visit(node.right);
|
||||
node.left->Visit(*this);
|
||||
node.right->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::BinaryOp& node)
|
||||
void AstRecursiveVisitor::Visit(BinaryExpression& node)
|
||||
{
|
||||
Visit(node.left);
|
||||
Visit(node.right);
|
||||
node.left->Visit(*this);
|
||||
node.right->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Branch& node)
|
||||
{
|
||||
for (auto& cond : node.condStatements)
|
||||
{
|
||||
Visit(cond.condition);
|
||||
Visit(cond.statement);
|
||||
}
|
||||
|
||||
if (node.elseStatement)
|
||||
Visit(node.elseStatement);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Cast& node)
|
||||
void AstRecursiveVisitor::Visit(CastExpression& node)
|
||||
{
|
||||
for (auto& expr : node.expressions)
|
||||
{
|
||||
if (!expr)
|
||||
break;
|
||||
|
||||
Visit(expr);
|
||||
expr->Visit(*this);
|
||||
}
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
void AstRecursiveVisitor::Visit(ConditionalExpression& node)
|
||||
{
|
||||
Visit(node.truePath);
|
||||
Visit(node.falsePath);
|
||||
node.truePath->Visit(*this);
|
||||
node.falsePath->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
{
|
||||
Visit(node.statement);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Constant& /*node*/)
|
||||
void AstRecursiveVisitor::Visit(ConstantExpression& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::DeclareVariable& node)
|
||||
{
|
||||
if (node.expression)
|
||||
Visit(node.expression);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Discard& /*node*/)
|
||||
void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ExpressionStatement& node)
|
||||
{
|
||||
Visit(node.expression);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Identifier& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::IntrinsicCall& node)
|
||||
void AstRecursiveVisitor::Visit(IntrinsicExpression& node)
|
||||
{
|
||||
for (auto& param : node.parameters)
|
||||
Visit(param);
|
||||
param->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::NoOp& /*node*/)
|
||||
void AstRecursiveVisitor::Visit(SwizzleExpression& node)
|
||||
{
|
||||
node.expression->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(BranchStatement& node)
|
||||
{
|
||||
for (auto& cond : node.condStatements)
|
||||
{
|
||||
cond.condition->Visit(*this);
|
||||
cond.statement->Visit(*this);
|
||||
}
|
||||
|
||||
if (node.elseStatement)
|
||||
node.elseStatement->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(ConditionalStatement& node)
|
||||
{
|
||||
node.statement->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(DeclareFunctionStatement& node)
|
||||
{
|
||||
for (auto& statement : node.statements)
|
||||
statement->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(DeclareStructStatement& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ReturnStatement& node)
|
||||
void AstRecursiveVisitor::Visit(DeclareVariableStatement& node)
|
||||
{
|
||||
if (node.returnExpr)
|
||||
Visit(node.returnExpr);
|
||||
if (node.initialExpression)
|
||||
node.initialExpression->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Sample2D& node)
|
||||
void AstRecursiveVisitor::Visit(DiscardStatement& /*node*/)
|
||||
{
|
||||
Visit(node.sampler);
|
||||
Visit(node.coordinates);
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::StatementBlock& node)
|
||||
void AstRecursiveVisitor::Visit(ExpressionStatement& node)
|
||||
{
|
||||
node.expression->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(MultiStatement& node)
|
||||
{
|
||||
for (auto& statement : node.statements)
|
||||
Visit(statement);
|
||||
statement->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::SwizzleOp& node)
|
||||
void AstRecursiveVisitor::Visit(NoOpStatement& /*node*/)
|
||||
{
|
||||
Visit(node.expression);
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(ReturnStatement& node)
|
||||
{
|
||||
if (node.returnExpr)
|
||||
node.returnExpr->Visit(*this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,221 +3,74 @@
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstSerializer.hpp>
|
||||
#include <Nazara/Shader/ShaderVarVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
namespace
|
||||
{
|
||||
constexpr UInt32 s_magicNumber = 0x4E534852;
|
||||
constexpr UInt32 s_currentVersion = 1;
|
||||
|
||||
class ShaderSerializerVisitor : public ShaderAstVisitor, public ShaderVarVisitor
|
||||
class ShaderSerializerVisitor : public AstExpressionVisitor, public AstStatementVisitor
|
||||
{
|
||||
public:
|
||||
ShaderSerializerVisitor(ShaderAstSerializerBase& serializer) :
|
||||
ShaderSerializerVisitor(AstSerializerBase& serializer) :
|
||||
m_serializer(serializer)
|
||||
{
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::AccessMember& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::AssignOp& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::BinaryOp& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Branch& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Cast& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Constant& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::DeclareVariable& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Discard& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ExpressionStatement& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Identifier& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::IntrinsicCall& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::NoOp& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ReturnStatement& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Sample2D& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::StatementBlock& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::SwizzleOp& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
|
||||
void Visit(ShaderNodes::BuiltinVariable& var) override
|
||||
{
|
||||
Serialize(var);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::InputVariable& var) override
|
||||
{
|
||||
Serialize(var);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::LocalVariable& var) override
|
||||
{
|
||||
Serialize(var);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::OutputVariable& var) override
|
||||
{
|
||||
Serialize(var);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ParameterVariable& var) override
|
||||
{
|
||||
Serialize(var);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::UniformVariable& var) override
|
||||
{
|
||||
Serialize(var);
|
||||
#define NAZARA_SHADERAST_NODE(Node) void Visit(Node& node) override \
|
||||
{ \
|
||||
m_serializer.Serialize(node); \
|
||||
}
|
||||
#include <Nazara/Shader/ShaderAstNodes.hpp>
|
||||
|
||||
private:
|
||||
template<typename T>
|
||||
void Serialize(const T& node)
|
||||
{
|
||||
// I know const_cast is evil but I don't have a better solution here (it's not used to write)
|
||||
m_serializer.Serialize(const_cast<T&>(node));
|
||||
}
|
||||
|
||||
ShaderAstSerializerBase& m_serializer;
|
||||
AstSerializerBase& m_serializer;
|
||||
};
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::AccessMember& node)
|
||||
void AstSerializerBase::Serialize(AccessMemberExpression& node)
|
||||
{
|
||||
Node(node.structExpr);
|
||||
Type(node.exprType);
|
||||
|
||||
Container(node.memberIndices);
|
||||
for (std::size_t& index : node.memberIndices)
|
||||
SizeT(index);
|
||||
Container(node.memberIdentifiers);
|
||||
for (std::string& identifier : node.memberIdentifiers)
|
||||
Value(identifier);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::AssignOp& node)
|
||||
void AstSerializerBase::Serialize(AssignExpression& node)
|
||||
{
|
||||
Enum(node.op);
|
||||
Node(node.left);
|
||||
Node(node.right);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::BinaryOp& node)
|
||||
void AstSerializerBase::Serialize(BinaryExpression& node)
|
||||
{
|
||||
Enum(node.op);
|
||||
Node(node.left);
|
||||
Node(node.right);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::Branch& node)
|
||||
void AstSerializerBase::Serialize(CastExpression& node)
|
||||
{
|
||||
Container(node.condStatements);
|
||||
for (auto& condStatement : node.condStatements)
|
||||
{
|
||||
Node(condStatement.condition);
|
||||
Node(condStatement.statement);
|
||||
}
|
||||
|
||||
Node(node.elseStatement);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node)
|
||||
{
|
||||
Enum(node.entry);
|
||||
Type(node.type);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::Cast& node)
|
||||
{
|
||||
Enum(node.exprType);
|
||||
Enum(node.targetType);
|
||||
for (auto& expr : node.expressions)
|
||||
Node(expr);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalExpression& node)
|
||||
void AstSerializerBase::Serialize(ConditionalExpression& node)
|
||||
{
|
||||
Value(node.conditionName);
|
||||
Node(node.truePath);
|
||||
Node(node.falsePath);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalStatement& node)
|
||||
{
|
||||
Value(node.conditionName);
|
||||
Node(node.statement);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::Constant& node)
|
||||
|
||||
void AstSerializerBase::Serialize(ConstantExpression& node)
|
||||
{
|
||||
UInt32 typeIndex;
|
||||
if (IsWriting())
|
||||
@@ -251,28 +104,19 @@ namespace Nz
|
||||
}
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::DeclareVariable& node)
|
||||
void AstSerializerBase::Serialize(DeclareVariableStatement& node)
|
||||
{
|
||||
Variable(node.variable);
|
||||
Node(node.expression);
|
||||
Value(node.varName);
|
||||
Type(node.varType);
|
||||
Node(node.initialExpression);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::Discard& /*node*/)
|
||||
void AstSerializerBase::Serialize(IdentifierExpression& node)
|
||||
{
|
||||
/* Nothing to do */
|
||||
Value(node.identifier);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::ExpressionStatement& node)
|
||||
{
|
||||
Node(node.expression);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::Identifier& node)
|
||||
{
|
||||
Variable(node.var);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::IntrinsicCall& node)
|
||||
void AstSerializerBase::Serialize(IntrinsicExpression& node)
|
||||
{
|
||||
Enum(node.intrinsic);
|
||||
Container(node.parameters);
|
||||
@@ -280,36 +124,7 @@ namespace Nz
|
||||
Node(param);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::NamedVariable& node)
|
||||
{
|
||||
Value(node.name);
|
||||
Type(node.type);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::NoOp& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::ReturnStatement& node)
|
||||
{
|
||||
Node(node.returnExpr);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::Sample2D& node)
|
||||
{
|
||||
Node(node.sampler);
|
||||
Node(node.coordinates);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::StatementBlock& node)
|
||||
{
|
||||
Container(node.statements);
|
||||
for (auto& statement : node.statements)
|
||||
Node(statement);
|
||||
}
|
||||
|
||||
void ShaderAstSerializerBase::Serialize(ShaderNodes::SwizzleOp& node)
|
||||
void AstSerializerBase::Serialize(SwizzleExpression& node)
|
||||
{
|
||||
SizeT(node.componentCount);
|
||||
Node(node.expression);
|
||||
@@ -319,100 +134,85 @@ namespace Nz
|
||||
}
|
||||
|
||||
|
||||
void ShaderAstSerializer::Serialize(const ShaderAst& shader)
|
||||
void AstSerializerBase::Serialize(BranchStatement& node)
|
||||
{
|
||||
Container(node.condStatements);
|
||||
for (auto& condStatement : node.condStatements)
|
||||
{
|
||||
Node(condStatement.condition);
|
||||
Node(condStatement.statement);
|
||||
}
|
||||
|
||||
Node(node.elseStatement);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(ConditionalStatement& node)
|
||||
{
|
||||
Value(node.conditionName);
|
||||
Node(node.statement);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
|
||||
{
|
||||
Value(node.name);
|
||||
Type(node.returnType);
|
||||
|
||||
Container(node.parameters);
|
||||
for (auto& parameter : node.parameters)
|
||||
{
|
||||
Value(parameter.name);
|
||||
Type(parameter.type);
|
||||
}
|
||||
|
||||
Container(node.statements);
|
||||
for (auto& statement : node.statements)
|
||||
Node(statement);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(DeclareStructStatement& node)
|
||||
{
|
||||
Value(node.description.name);
|
||||
|
||||
Container(node.description.members);
|
||||
for (auto& member : node.description.members)
|
||||
{
|
||||
Value(member.name);
|
||||
Type(member.type);
|
||||
}
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(DiscardStatement& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(ExpressionStatement& node)
|
||||
{
|
||||
Node(node.expression);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(MultiStatement& node)
|
||||
{
|
||||
Container(node.statements);
|
||||
for (auto& statement : node.statements)
|
||||
Node(statement);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(NoOpStatement& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(ReturnStatement& node)
|
||||
{
|
||||
Node(node.returnExpr);
|
||||
}
|
||||
|
||||
void ShaderAstSerializer::Serialize(StatementPtr& shader)
|
||||
{
|
||||
m_stream << s_magicNumber << s_currentVersion;
|
||||
|
||||
m_stream << UInt32(shader.GetStage());
|
||||
|
||||
auto SerializeType = [&](const ShaderExpressionType& type)
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
|
||||
{
|
||||
m_stream << UInt8(0);
|
||||
m_stream << UInt32(arg);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, std::string>)
|
||||
{
|
||||
m_stream << UInt8(1);
|
||||
m_stream << arg;
|
||||
}
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, type);
|
||||
};
|
||||
|
||||
auto SerializeInputOutput = [&](auto& inout)
|
||||
{
|
||||
m_stream << UInt32(inout.size());
|
||||
for (const auto& data : inout)
|
||||
{
|
||||
m_stream << data.name;
|
||||
SerializeType(data.type);
|
||||
|
||||
m_stream << data.locationIndex.has_value();
|
||||
if (data.locationIndex)
|
||||
m_stream << UInt32(data.locationIndex.value());
|
||||
}
|
||||
};
|
||||
|
||||
// Conditions
|
||||
m_stream << UInt32(shader.GetConditionCount());
|
||||
for (const auto& cond : shader.GetConditions())
|
||||
m_stream << cond.name;
|
||||
|
||||
// Structs
|
||||
m_stream << UInt32(shader.GetStructCount());
|
||||
for (const auto& s : shader.GetStructs())
|
||||
{
|
||||
m_stream << s.name;
|
||||
m_stream << UInt32(s.members.size());
|
||||
for (const auto& member : s.members)
|
||||
{
|
||||
m_stream << member.name;
|
||||
SerializeType(member.type);
|
||||
}
|
||||
}
|
||||
|
||||
// Inputs / Outputs
|
||||
SerializeInputOutput(shader.GetInputs());
|
||||
SerializeInputOutput(shader.GetOutputs());
|
||||
|
||||
// Uniforms
|
||||
m_stream << UInt32(shader.GetUniformCount());
|
||||
for (const auto& uniform : shader.GetUniforms())
|
||||
{
|
||||
m_stream << uniform.name;
|
||||
SerializeType(uniform.type);
|
||||
|
||||
m_stream << uniform.bindingIndex.has_value();
|
||||
if (uniform.bindingIndex)
|
||||
m_stream << UInt32(uniform.bindingIndex.value());
|
||||
|
||||
m_stream << uniform.memoryLayout.has_value();
|
||||
if (uniform.memoryLayout)
|
||||
m_stream << UInt32(uniform.memoryLayout.value());
|
||||
}
|
||||
|
||||
// Functions
|
||||
m_stream << UInt32(shader.GetFunctionCount());
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
m_stream << func.name;
|
||||
SerializeType(func.returnType);
|
||||
|
||||
m_stream << UInt32(func.parameters.size());
|
||||
for (const auto& param : func.parameters)
|
||||
{
|
||||
m_stream << param.name;
|
||||
SerializeType(param.type);
|
||||
}
|
||||
|
||||
Node(func.statement);
|
||||
}
|
||||
Node(shader);
|
||||
|
||||
m_stream.FlushBits();
|
||||
}
|
||||
@@ -422,9 +222,21 @@ namespace Nz
|
||||
return true;
|
||||
}
|
||||
|
||||
void ShaderAstSerializer::Node(ShaderNodes::NodePtr& node)
|
||||
void ShaderAstSerializer::Node(ExpressionPtr& node)
|
||||
{
|
||||
ShaderNodes::NodeType nodeType = (node) ? node->GetType() : ShaderNodes::NodeType::None;
|
||||
NodeType nodeType = (node) ? node->GetType() : NodeType::None;
|
||||
m_stream << static_cast<Int32>(nodeType);
|
||||
|
||||
if (node)
|
||||
{
|
||||
ShaderSerializerVisitor visitor(*this);
|
||||
node->Visit(visitor);
|
||||
}
|
||||
}
|
||||
|
||||
void ShaderAstSerializer::Node(StatementPtr& node)
|
||||
{
|
||||
NodeType nodeType = (node) ? node->GetType() : NodeType::None;
|
||||
m_stream << static_cast<Int32>(nodeType);
|
||||
|
||||
if (node)
|
||||
@@ -439,7 +251,7 @@ namespace Nz
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
|
||||
if constexpr (std::is_same_v<T, BasicType>)
|
||||
{
|
||||
m_stream << UInt8(0);
|
||||
m_stream << UInt32(arg);
|
||||
@@ -454,11 +266,6 @@ namespace Nz
|
||||
}, type);
|
||||
}
|
||||
|
||||
void ShaderAstSerializer::Node(const ShaderNodes::NodePtr& node)
|
||||
{
|
||||
Node(const_cast<ShaderNodes::NodePtr&>(node)); //< Yes const_cast is ugly but it won't be used for writing
|
||||
}
|
||||
|
||||
void ShaderAstSerializer::Value(bool& val)
|
||||
{
|
||||
m_stream << val;
|
||||
@@ -529,19 +336,7 @@ namespace Nz
|
||||
m_stream << val;
|
||||
}
|
||||
|
||||
void ShaderAstSerializer::Variable(ShaderNodes::VariablePtr& var)
|
||||
{
|
||||
ShaderNodes::VariableType nodeType = (var) ? var->GetType() : ShaderNodes::VariableType::None;
|
||||
m_stream << static_cast<Int32>(nodeType);
|
||||
|
||||
if (var)
|
||||
{
|
||||
ShaderSerializerVisitor visitor(*this);
|
||||
var->Visit(visitor);
|
||||
}
|
||||
}
|
||||
|
||||
ShaderAst ShaderAstUnserializer::Unserialize()
|
||||
StatementPtr ShaderAstUnserializer::Unserialize()
|
||||
{
|
||||
UInt32 magicNumber;
|
||||
UInt32 version;
|
||||
@@ -553,122 +348,13 @@ namespace Nz
|
||||
if (version > s_currentVersion)
|
||||
throw std::runtime_error("unsupported version");
|
||||
|
||||
UInt32 shaderStage;
|
||||
m_stream >> shaderStage;
|
||||
StatementPtr node;
|
||||
|
||||
ShaderAst shader(static_cast<ShaderStageType>(shaderStage));
|
||||
Node(node);
|
||||
if (!node)
|
||||
throw std::runtime_error("functions can only have statements");
|
||||
|
||||
// Conditions
|
||||
UInt32 conditionCount;
|
||||
m_stream >> conditionCount;
|
||||
for (UInt32 i = 0; i < conditionCount; ++i)
|
||||
{
|
||||
std::string conditionName;
|
||||
Value(conditionName);
|
||||
|
||||
shader.AddCondition(std::move(conditionName));
|
||||
}
|
||||
|
||||
// Structs
|
||||
UInt32 structCount;
|
||||
m_stream >> structCount;
|
||||
for (UInt32 i = 0; i < structCount; ++i)
|
||||
{
|
||||
std::string structName;
|
||||
std::vector<ShaderAst::StructMember> members;
|
||||
|
||||
Value(structName);
|
||||
Container(members);
|
||||
|
||||
for (auto& member : members)
|
||||
{
|
||||
Value(member.name);
|
||||
Type(member.type);
|
||||
}
|
||||
|
||||
shader.AddStruct(std::move(structName), std::move(members));
|
||||
}
|
||||
|
||||
// Inputs
|
||||
UInt32 inputCount;
|
||||
m_stream >> inputCount;
|
||||
for (UInt32 i = 0; i < inputCount; ++i)
|
||||
{
|
||||
std::string inputName;
|
||||
ShaderExpressionType inputType;
|
||||
std::optional<std::size_t> location;
|
||||
|
||||
Value(inputName);
|
||||
Type(inputType);
|
||||
OptVal(location);
|
||||
|
||||
shader.AddInput(std::move(inputName), std::move(inputType), location);
|
||||
}
|
||||
|
||||
// Outputs
|
||||
UInt32 outputCount;
|
||||
m_stream >> outputCount;
|
||||
for (UInt32 i = 0; i < outputCount; ++i)
|
||||
{
|
||||
std::string outputName;
|
||||
ShaderExpressionType outputType;
|
||||
std::optional<std::size_t> location;
|
||||
|
||||
Value(outputName);
|
||||
Type(outputType);
|
||||
OptVal(location);
|
||||
|
||||
shader.AddOutput(std::move(outputName), std::move(outputType), location);
|
||||
}
|
||||
|
||||
// Uniforms
|
||||
UInt32 uniformCount;
|
||||
m_stream >> uniformCount;
|
||||
for (UInt32 i = 0; i < uniformCount; ++i)
|
||||
{
|
||||
std::string name;
|
||||
ShaderExpressionType type;
|
||||
std::optional<std::size_t> binding;
|
||||
std::optional<ShaderNodes::MemoryLayout> memLayout;
|
||||
|
||||
Value(name);
|
||||
Type(type);
|
||||
OptVal(binding);
|
||||
OptEnum(memLayout);
|
||||
|
||||
shader.AddUniform(std::move(name), std::move(type), std::move(binding), std::move(memLayout));
|
||||
}
|
||||
|
||||
// Functions
|
||||
UInt32 funcCount;
|
||||
m_stream >> funcCount;
|
||||
for (UInt32 i = 0; i < funcCount; ++i)
|
||||
{
|
||||
std::string name;
|
||||
ShaderExpressionType retType;
|
||||
std::vector<ShaderAst::FunctionParameter> parameters;
|
||||
|
||||
Value(name);
|
||||
Type(retType);
|
||||
|
||||
Container(parameters);
|
||||
for (auto& param : parameters)
|
||||
{
|
||||
Value(param.name);
|
||||
Type(param.type);
|
||||
}
|
||||
|
||||
ShaderNodes::NodePtr node;
|
||||
Node(node);
|
||||
if (!node || !node->IsStatement())
|
||||
throw std::runtime_error("functions can only have statements");
|
||||
|
||||
ShaderNodes::StatementPtr statement = std::static_pointer_cast<ShaderNodes::Statement>(node);
|
||||
|
||||
shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), std::move(retType));
|
||||
}
|
||||
|
||||
return shader;
|
||||
return node;
|
||||
}
|
||||
|
||||
bool ShaderAstUnserializer::IsWriting() const
|
||||
@@ -676,41 +362,50 @@ namespace Nz
|
||||
return false;
|
||||
}
|
||||
|
||||
void ShaderAstUnserializer::Node(ShaderNodes::NodePtr& node)
|
||||
void ShaderAstUnserializer::Node(ExpressionPtr& node)
|
||||
{
|
||||
Int32 nodeTypeInt;
|
||||
m_stream >> nodeTypeInt;
|
||||
|
||||
if (nodeTypeInt < static_cast<Int32>(ShaderNodes::NodeType::None) || nodeTypeInt > static_cast<Int32>(ShaderNodes::NodeType::Max))
|
||||
if (nodeTypeInt < static_cast<Int32>(NodeType::None) || nodeTypeInt > static_cast<Int32>(NodeType::Max))
|
||||
throw std::runtime_error("invalid node type");
|
||||
|
||||
ShaderNodes::NodeType nodeType = static_cast<ShaderNodes::NodeType>(nodeTypeInt);
|
||||
|
||||
#define HandleType(Type) case ShaderNodes::NodeType:: Type : node = std::make_shared<ShaderNodes:: Type>(); break
|
||||
NodeType nodeType = static_cast<NodeType>(nodeTypeInt);
|
||||
switch (nodeType)
|
||||
{
|
||||
case ShaderNodes::NodeType::None: break;
|
||||
case NodeType::None: break;
|
||||
|
||||
HandleType(AccessMember);
|
||||
HandleType(AssignOp);
|
||||
HandleType(BinaryOp);
|
||||
HandleType(Branch);
|
||||
HandleType(Cast);
|
||||
HandleType(Constant);
|
||||
HandleType(ConditionalExpression);
|
||||
HandleType(ConditionalStatement);
|
||||
HandleType(DeclareVariable);
|
||||
HandleType(Discard);
|
||||
HandleType(ExpressionStatement);
|
||||
HandleType(Identifier);
|
||||
HandleType(IntrinsicCall);
|
||||
HandleType(NoOp);
|
||||
HandleType(ReturnStatement);
|
||||
HandleType(Sample2D);
|
||||
HandleType(SwizzleOp);
|
||||
HandleType(StatementBlock);
|
||||
#define NAZARA_SHADERAST_EXPRESSION(Node) case NodeType:: Node : node = std::make_unique<Node>(); break;
|
||||
#include <Nazara/Shader/ShaderAstNodes.hpp>
|
||||
|
||||
default: throw std::runtime_error("unexpected node type");
|
||||
}
|
||||
|
||||
if (node)
|
||||
{
|
||||
ShaderSerializerVisitor visitor(*this);
|
||||
node->Visit(visitor);
|
||||
}
|
||||
}
|
||||
|
||||
void ShaderAstUnserializer::Node(StatementPtr& node)
|
||||
{
|
||||
Int32 nodeTypeInt;
|
||||
m_stream >> nodeTypeInt;
|
||||
|
||||
if (nodeTypeInt < static_cast<Int32>(NodeType::None) || nodeTypeInt > static_cast<Int32>(NodeType::Max))
|
||||
throw std::runtime_error("invalid node type");
|
||||
|
||||
NodeType nodeType = static_cast<NodeType>(nodeTypeInt);
|
||||
switch (nodeType)
|
||||
{
|
||||
case NodeType::None: break;
|
||||
|
||||
#define NAZARA_SHADERAST_STATEMENT(Node) case NodeType:: Node : node = std::make_unique<Node>(); break;
|
||||
#include <Nazara/Shader/ShaderAstNodes.hpp>
|
||||
|
||||
default: throw std::runtime_error("unexpected node type");
|
||||
}
|
||||
#undef HandleType
|
||||
|
||||
if (node)
|
||||
{
|
||||
@@ -728,7 +423,7 @@ namespace Nz
|
||||
{
|
||||
case 0: //< Primitive
|
||||
{
|
||||
ShaderNodes::BasicType exprType;
|
||||
BasicType exprType;
|
||||
Enum(exprType);
|
||||
|
||||
type = exprType;
|
||||
@@ -819,36 +514,8 @@ namespace Nz
|
||||
m_stream >> val;
|
||||
}
|
||||
|
||||
void ShaderAstUnserializer::Variable(ShaderNodes::VariablePtr& var)
|
||||
{
|
||||
Int32 nodeTypeInt;
|
||||
m_stream >> nodeTypeInt;
|
||||
|
||||
ShaderNodes::VariableType nodeType = static_cast<ShaderNodes:: VariableType>(nodeTypeInt);
|
||||
|
||||
#define HandleType(Type) case ShaderNodes::VariableType:: Type : var = std::make_shared<ShaderNodes::Type>(); break
|
||||
switch (nodeType)
|
||||
{
|
||||
case ShaderNodes::VariableType::None: break;
|
||||
|
||||
HandleType(BuiltinVariable);
|
||||
HandleType(InputVariable);
|
||||
HandleType(LocalVariable);
|
||||
HandleType(ParameterVariable);
|
||||
HandleType(OutputVariable);
|
||||
HandleType(UniformVariable);
|
||||
}
|
||||
#undef HandleType
|
||||
|
||||
if (var)
|
||||
{
|
||||
ShaderSerializerVisitor visitor(*this);
|
||||
var->Visit(visitor);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ByteArray SerializeShader(const ShaderAst& shader)
|
||||
ByteArray SerializeShader(StatementPtr& shader)
|
||||
{
|
||||
ByteArray byteArray;
|
||||
ByteStream stream(&byteArray, OpenModeFlags(OpenMode_WriteOnly));
|
||||
@@ -859,7 +526,7 @@ namespace Nz
|
||||
return byteArray;
|
||||
}
|
||||
|
||||
ShaderAst UnserializeShader(ByteStream& stream)
|
||||
StatementPtr UnserializeShader(ByteStream& stream)
|
||||
{
|
||||
ShaderAstUnserializer unserializer(stream);
|
||||
return unserializer.Unserialize();
|
||||
|
||||
@@ -2,15 +2,10 @@
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderVarVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
ShaderVarVisitor::~ShaderVarVisitor() = default;
|
||||
|
||||
void ShaderVarVisitor::Visit(const ShaderNodes::VariablePtr& node)
|
||||
{
|
||||
node->Visit(*this);
|
||||
}
|
||||
AstStatementVisitor::~AstStatementVisitor() = default;
|
||||
}
|
||||
15
src/Nazara/Shader/ShaderAstStatementVisitorExcept.cpp
Normal file
15
src/Nazara/Shader/ShaderAstStatementVisitorExcept.cpp
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstStatementVisitorExcept.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
#define NAZARA_SHADERAST_STATEMENT(Node) void StatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
|
||||
{ \
|
||||
throw std::runtime_error("unexpected " #Node " node"); \
|
||||
}
|
||||
#include <Nazara/Shader/ShaderAstNodes.hpp>
|
||||
}
|
||||
@@ -5,69 +5,65 @@
|
||||
#include <Nazara/Shader/ShaderAstUtils.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
ShaderNodes::ExpressionCategory ShaderAstValueCategory::GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression)
|
||||
ExpressionCategory ShaderAstValueCategory::GetExpressionCategory(Expression& expression)
|
||||
{
|
||||
Visit(expression);
|
||||
expression.Visit(*this);
|
||||
return m_expressionCategory;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::AccessMember& node)
|
||||
void ShaderAstValueCategory::Visit(AccessMemberExpression& node)
|
||||
{
|
||||
Visit(node.structExpr);
|
||||
node.structExpr->Visit(*this);
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::AssignOp& node)
|
||||
void ShaderAstValueCategory::Visit(AssignExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::BinaryOp& node)
|
||||
void ShaderAstValueCategory::Visit(BinaryExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::Cast& node)
|
||||
void ShaderAstValueCategory::Visit(CastExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
void ShaderAstValueCategory::Visit(ConditionalExpression& node)
|
||||
{
|
||||
Visit(node.truePath);
|
||||
ShaderNodes::ExpressionCategory trueExprCategory = m_expressionCategory;
|
||||
Visit(node.falsePath);
|
||||
ShaderNodes::ExpressionCategory falseExprCategory = m_expressionCategory;
|
||||
node.truePath->Visit(*this);
|
||||
ExpressionCategory trueExprCategory = m_expressionCategory;
|
||||
|
||||
if (trueExprCategory == ShaderNodes::ExpressionCategory::RValue || falseExprCategory == ShaderNodes::ExpressionCategory::RValue)
|
||||
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
|
||||
node.falsePath->Visit(*this);
|
||||
ExpressionCategory falseExprCategory = m_expressionCategory;
|
||||
|
||||
if (trueExprCategory == ExpressionCategory::RValue || falseExprCategory == ExpressionCategory::RValue)
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
else
|
||||
m_expressionCategory = ShaderNodes::ExpressionCategory::LValue;
|
||||
m_expressionCategory = ExpressionCategory::LValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::Constant& node)
|
||||
void ShaderAstValueCategory::Visit(ConstantExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::Identifier& node)
|
||||
void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ShaderNodes::ExpressionCategory::LValue;
|
||||
m_expressionCategory = ExpressionCategory::LValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::IntrinsicCall& node)
|
||||
void ShaderAstValueCategory::Visit(IntrinsicExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::Sample2D& node)
|
||||
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
|
||||
{
|
||||
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(ShaderNodes::SwizzleOp& node)
|
||||
{
|
||||
Visit(node.expression);
|
||||
node.expression->Visit(*this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,48 +4,40 @@
|
||||
|
||||
#include <Nazara/Shader/ShaderAstValidator.hpp>
|
||||
#include <Nazara/Core/CallOnExit.hpp>
|
||||
#include <Nazara/Shader/ShaderAst.hpp>
|
||||
#include <Nazara/Shader/ShaderAstUtils.hpp>
|
||||
#include <Nazara/Shader/ShaderVariables.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <vector>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
struct AstError
|
||||
{
|
||||
std::string errMsg;
|
||||
};
|
||||
|
||||
struct ShaderAstValidator::Context
|
||||
struct AstValidator::Context
|
||||
{
|
||||
struct Local
|
||||
{
|
||||
std::string name;
|
||||
ShaderExpressionType type;
|
||||
};
|
||||
|
||||
const ShaderAst::Function* currentFunction;
|
||||
std::vector<Local> declaredLocals;
|
||||
std::vector<std::size_t> blockLocalIndex;
|
||||
//const ShaderAst::Function* currentFunction;
|
||||
std::optional<std::size_t> activeScopeId;
|
||||
AstCache* cache;
|
||||
};
|
||||
|
||||
bool ShaderAstValidator::Validate(std::string* error)
|
||||
bool AstValidator::Validate(StatementPtr& node, std::string* error, AstCache* cache)
|
||||
{
|
||||
try
|
||||
{
|
||||
for (std::size_t i = 0; i < m_shader.GetFunctionCount(); ++i)
|
||||
{
|
||||
const auto& func = m_shader.GetFunction(i);
|
||||
AstCache dummy;
|
||||
|
||||
Context currentContext;
|
||||
currentContext.currentFunction = &func;
|
||||
Context currentContext;
|
||||
currentContext.cache = (cache) ? cache : &dummy;
|
||||
|
||||
m_context = ¤tContext;
|
||||
CallOnExit resetContext([&] { m_context = nullptr; });
|
||||
m_context = ¤tContext;
|
||||
CallOnExit resetContext([&] { m_context = nullptr; });
|
||||
|
||||
func.statement->Visit(*this);
|
||||
}
|
||||
EnterScope();
|
||||
node->Visit(*this);
|
||||
ExitScope();
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -58,148 +50,183 @@ namespace Nz
|
||||
}
|
||||
}
|
||||
|
||||
const ShaderNodes::ExpressionPtr& ShaderAstValidator::MandatoryExpr(const ShaderNodes::ExpressionPtr& node)
|
||||
{
|
||||
MandatoryNode(node);
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
const ShaderNodes::NodePtr& ShaderAstValidator::MandatoryNode(const ShaderNodes::NodePtr& node)
|
||||
Expression& AstValidator::MandatoryExpr(ExpressionPtr& node)
|
||||
{
|
||||
if (!node)
|
||||
throw AstError{ "Invalid node" };
|
||||
throw AstError{ "Invalid expression" };
|
||||
|
||||
return node;
|
||||
return *node;
|
||||
}
|
||||
|
||||
void ShaderAstValidator::TypeMustMatch(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right)
|
||||
Statement& AstValidator::MandatoryStatement(StatementPtr& node)
|
||||
{
|
||||
return TypeMustMatch(left->GetExpressionType(), right->GetExpressionType());
|
||||
if (!node)
|
||||
throw AstError{ "Invalid statement" };
|
||||
|
||||
return *node;
|
||||
}
|
||||
|
||||
void ShaderAstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right)
|
||||
void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
|
||||
{
|
||||
return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache));
|
||||
}
|
||||
|
||||
void AstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right)
|
||||
{
|
||||
if (left != right)
|
||||
throw AstError{ "Left expression type must match right expression type" };
|
||||
}
|
||||
|
||||
const ShaderAst::StructMember& ShaderAstValidator::CheckField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers)
|
||||
ShaderExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
|
||||
{
|
||||
const auto& structs = m_shader.GetStructs();
|
||||
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
|
||||
if (it == structs.end())
|
||||
throw AstError{ "invalid structure" };
|
||||
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName);
|
||||
if (!identifier)
|
||||
throw AstError{ "unknown identifier " + structName };
|
||||
|
||||
const ShaderAst::Struct& s = *it;
|
||||
if (*memberIndex >= s.members.size())
|
||||
throw AstError{ "member index out of bounds" };
|
||||
if (std::holds_alternative<StructDescription>(identifier->value))
|
||||
throw AstError{ "identifier is not a struct" };
|
||||
|
||||
const auto& member = s.members[*memberIndex];
|
||||
const StructDescription& s = std::get<StructDescription>(identifier->value);
|
||||
|
||||
auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
|
||||
if (memberIt == s.members.end())
|
||||
throw AstError{ "unknown field " + memberIdentifier[0]};
|
||||
|
||||
const auto& member = *memberIt;
|
||||
|
||||
if (remainingMembers > 1)
|
||||
{
|
||||
if (!IsStructType(member.type))
|
||||
throw AstError{ "member type does not match node type" };
|
||||
|
||||
return CheckField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
|
||||
}
|
||||
return CheckField(std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
|
||||
else
|
||||
return member;
|
||||
return member.type;
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node)
|
||||
AstCache::Scope& AstValidator::EnterScope()
|
||||
{
|
||||
const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType();
|
||||
std::size_t newScopeId = m_context->cache->scopes.size();
|
||||
|
||||
std::optional<std::size_t> previousScope = m_context->activeScopeId;
|
||||
|
||||
auto& newScope = m_context->cache->scopes.emplace_back();
|
||||
newScope.parentScopeIndex = previousScope;
|
||||
|
||||
m_context->activeScopeId = newScopeId;
|
||||
return m_context->cache->scopes[newScopeId];
|
||||
}
|
||||
|
||||
void AstValidator::ExitScope()
|
||||
{
|
||||
assert(m_context->activeScopeId);
|
||||
auto& previousScope = m_context->cache->scopes[*m_context->activeScopeId];
|
||||
m_context->activeScopeId = previousScope.parentScopeIndex;
|
||||
}
|
||||
|
||||
void AstValidator::RegisterExpressionType(Expression& node, ShaderExpressionType expressionType)
|
||||
{
|
||||
m_context->cache->nodeExpressionType[&node] = std::move(expressionType);
|
||||
}
|
||||
|
||||
void AstValidator::RegisterScope(Node& node)
|
||||
{
|
||||
if (m_context->activeScopeId)
|
||||
m_context->cache->scopeIdByNode[&node] = *m_context->activeScopeId;
|
||||
}
|
||||
|
||||
void AstValidator::Visit(AccessMemberExpression& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache);
|
||||
if (!IsStructType(exprType))
|
||||
throw AstError{ "expression is not a structure" };
|
||||
|
||||
const std::string& structName = std::get<std::string>(exprType);
|
||||
|
||||
const auto& member = CheckField(structName, node.memberIndices.data(), node.memberIndices.size());
|
||||
if (member.type != node.exprType)
|
||||
throw AstError{ "member type does not match node type" };
|
||||
RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()));
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::AssignOp& node)
|
||||
void AstValidator::Visit(AssignExpression& node)
|
||||
{
|
||||
MandatoryNode(node.left);
|
||||
MandatoryNode(node.right);
|
||||
RegisterScope(node);
|
||||
|
||||
MandatoryExpr(node.left);
|
||||
MandatoryExpr(node.right);
|
||||
TypeMustMatch(node.left, node.right);
|
||||
|
||||
if (GetExpressionCategory(node.left) != ShaderNodes::ExpressionCategory::LValue)
|
||||
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
|
||||
throw AstError { "Assignation is only possible with a l-value" };
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::BinaryOp& node)
|
||||
void AstValidator::Visit(BinaryExpression& node)
|
||||
{
|
||||
MandatoryNode(node.left);
|
||||
MandatoryNode(node.right);
|
||||
RegisterScope(node);
|
||||
|
||||
const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType();
|
||||
// Register expression type
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
const ShaderExpressionType& leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache);
|
||||
if (!IsBasicType(leftExprType))
|
||||
throw AstError{ "left expression type does not support binary operation" };
|
||||
|
||||
const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType();
|
||||
const ShaderExpressionType& rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache);
|
||||
if (!IsBasicType(rightExprType))
|
||||
throw AstError{ "right expression type does not support binary operation" };
|
||||
|
||||
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
|
||||
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
|
||||
BasicType leftType = std::get<BasicType>(leftExprType);
|
||||
BasicType rightType = std::get<BasicType>(rightExprType);
|
||||
|
||||
switch (node.op)
|
||||
{
|
||||
case ShaderNodes::BinaryType::CompGe:
|
||||
case ShaderNodes::BinaryType::CompGt:
|
||||
case ShaderNodes::BinaryType::CompLe:
|
||||
case ShaderNodes::BinaryType::CompLt:
|
||||
if (leftType == ShaderNodes::BasicType::Boolean)
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
if (leftType == BasicType::Boolean)
|
||||
throw AstError{ "this operation is not supported for booleans" };
|
||||
|
||||
[[fallthrough]];
|
||||
case ShaderNodes::BinaryType::Add:
|
||||
case ShaderNodes::BinaryType::CompEq:
|
||||
case ShaderNodes::BinaryType::CompNe:
|
||||
case ShaderNodes::BinaryType::Subtract:
|
||||
case BinaryType::Add:
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompNe:
|
||||
case BinaryType::Subtract:
|
||||
TypeMustMatch(node.left, node.right);
|
||||
break;
|
||||
|
||||
case ShaderNodes::BinaryType::Multiply:
|
||||
case ShaderNodes::BinaryType::Divide:
|
||||
case BinaryType::Multiply:
|
||||
case BinaryType::Divide:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case BasicType::Float1:
|
||||
case BasicType::Int1:
|
||||
{
|
||||
if (ShaderNodes::Node::GetComponentType(rightType) != leftType)
|
||||
if (GetComponentType(rightType) != leftType)
|
||||
throw AstError{ "Left expression type is not compatible with right expression type" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case BasicType::Float2:
|
||||
case BasicType::Float3:
|
||||
case BasicType::Float4:
|
||||
case BasicType::Int2:
|
||||
case BasicType::Int3:
|
||||
case BasicType::Int4:
|
||||
{
|
||||
if (leftType != rightType && rightType != ShaderNodes::Node::GetComponentType(leftType))
|
||||
if (leftType != rightType && rightType != GetComponentType(leftType))
|
||||
throw AstError{ "Left expression type is not compatible with right expression type" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case BasicType::Mat4x4:
|
||||
{
|
||||
switch (rightType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case BasicType::Float1:
|
||||
case BasicType::Float4:
|
||||
case BasicType::Mat4x4:
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -211,120 +238,86 @@ namespace Nz
|
||||
|
||||
default:
|
||||
TypeMustMatch(node.left, node.right);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::Branch& node)
|
||||
void AstValidator::Visit(CastExpression& node)
|
||||
{
|
||||
for (const auto& condStatement : node.condStatements)
|
||||
{
|
||||
const ShaderExpressionType& condType = MandatoryExpr(condStatement.condition)->GetExpressionType();
|
||||
if (!IsBasicType(condType) || std::get<ShaderNodes::BasicType>(condType) != ShaderNodes::BasicType::Boolean)
|
||||
throw AstError{ "if expression must resolve to boolean type" };
|
||||
RegisterScope(node);
|
||||
|
||||
MandatoryNode(condStatement.statement);
|
||||
}
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::Cast& node)
|
||||
{
|
||||
unsigned int componentCount = 0;
|
||||
unsigned int requiredComponents = node.GetComponentCount(node.exprType);
|
||||
for (const auto& exprPtr : node.expressions)
|
||||
unsigned int requiredComponents = GetComponentCount(node.targetType);
|
||||
for (auto& exprPtr : node.expressions)
|
||||
{
|
||||
if (!exprPtr)
|
||||
break;
|
||||
|
||||
const ShaderExpressionType& exprType = exprPtr->GetExpressionType();
|
||||
ShaderExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
|
||||
if (!IsBasicType(exprType))
|
||||
throw AstError{ "incompatible type" };
|
||||
|
||||
componentCount += node.GetComponentCount(std::get<ShaderNodes::BasicType>(exprType));
|
||||
componentCount += GetComponentCount(std::get<BasicType>(exprType));
|
||||
}
|
||||
|
||||
if (componentCount != requiredComponents)
|
||||
throw AstError{ "component count doesn't match required component count" };
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
void AstValidator::Visit(ConditionalExpression& node)
|
||||
{
|
||||
MandatoryNode(node.truePath);
|
||||
MandatoryNode(node.falsePath);
|
||||
MandatoryExpr(node.truePath);
|
||||
MandatoryExpr(node.falsePath);
|
||||
|
||||
if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
|
||||
throw AstError{ "condition not found" };
|
||||
RegisterScope(node);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
|
||||
// throw AstError{ "condition not found" };
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
void AstValidator::Visit(ConstantExpression& node)
|
||||
{
|
||||
MandatoryNode(node.statement);
|
||||
|
||||
if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
|
||||
throw AstError{ "condition not found" };
|
||||
RegisterScope(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::Constant& /*node*/)
|
||||
{
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::DeclareVariable& node)
|
||||
void AstValidator::Visit(IdentifierExpression& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
||||
if (node.variable->GetType() != ShaderNodes::VariableType::LocalVariable)
|
||||
throw AstError{ "Only local variables can be declared in a statement" };
|
||||
if (!m_context->activeScopeId)
|
||||
throw AstError{ "no scope" };
|
||||
|
||||
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
|
||||
RegisterScope(node);
|
||||
|
||||
auto& local = m_context->declaredLocals.emplace_back();
|
||||
local.name = localVar.name;
|
||||
local.type = localVar.type;
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, node.identifier);
|
||||
if (!identifier)
|
||||
throw AstError{ "Unknown variable " + node.identifier };
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::ExpressionStatement& node)
|
||||
|
||||
void AstValidator::Visit(IntrinsicExpression& node)
|
||||
{
|
||||
MandatoryNode(node.expression);
|
||||
RegisterScope(node);
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::Identifier& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
||||
if (!node.var)
|
||||
throw AstError{ "Invalid variable" };
|
||||
|
||||
Visit(node.var);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::IntrinsicCall& node)
|
||||
{
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
case ShaderNodes::IntrinsicType::CrossProduct:
|
||||
case ShaderNodes::IntrinsicType::DotProduct:
|
||||
case IntrinsicType::CrossProduct:
|
||||
case IntrinsicType::DotProduct:
|
||||
{
|
||||
if (node.parameters.size() != 2)
|
||||
throw AstError { "Expected 2 parameters" };
|
||||
|
||||
for (auto& param : node.parameters)
|
||||
MandatoryNode(param);
|
||||
MandatoryExpr(param);
|
||||
|
||||
ShaderExpressionType type = node.parameters.front()->GetExpressionType();
|
||||
ShaderExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
|
||||
for (std::size_t i = 1; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (type != node.parameters[i]->GetExpressionType())
|
||||
if (type != GetExpressionType(MandatoryExpr(node.parameters[i])), m_context->cache)
|
||||
throw AstError{ "All type must match" };
|
||||
}
|
||||
|
||||
@@ -334,180 +327,176 @@ namespace Nz
|
||||
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
case ShaderNodes::IntrinsicType::CrossProduct:
|
||||
case IntrinsicType::CrossProduct:
|
||||
{
|
||||
if (node.parameters[0]->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float3 })
|
||||
if (GetExpressionType(*node.parameters[0]) != ShaderExpressionType{ BasicType::Float3 }, m_context->cache)
|
||||
throw AstError{ "CrossProduct only works with Float3 expressions" };
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::IntrinsicType::DotProduct:
|
||||
case IntrinsicType::DotProduct:
|
||||
break;
|
||||
}
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::ReturnStatement& node)
|
||||
void AstValidator::Visit(SwizzleExpression& node)
|
||||
{
|
||||
if (m_context->currentFunction->returnType != ShaderExpressionType(ShaderNodes::BasicType::Void))
|
||||
{
|
||||
if (MandatoryExpr(node.returnExpr)->GetExpressionType() != m_context->currentFunction->returnType)
|
||||
throw AstError{ "Return type doesn't match function return type" };
|
||||
}
|
||||
else
|
||||
{
|
||||
if (node.returnExpr)
|
||||
throw AstError{ "Unexpected expression for return (function doesn't return)" };
|
||||
}
|
||||
RegisterScope(node);
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::Sample2D& node)
|
||||
{
|
||||
if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D })
|
||||
throw AstError{ "Sampler must be a Sampler2D" };
|
||||
|
||||
if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float2 })
|
||||
throw AstError{ "Coordinates must be a Float2" };
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::StatementBlock& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
||||
m_context->blockLocalIndex.push_back(m_context->declaredLocals.size());
|
||||
|
||||
for (const auto& statement : node.statements)
|
||||
MandatoryNode(statement);
|
||||
|
||||
assert(m_context->declaredLocals.size() >= m_context->blockLocalIndex.back());
|
||||
m_context->declaredLocals.resize(m_context->blockLocalIndex.back());
|
||||
m_context->blockLocalIndex.pop_back();
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::SwizzleOp& node)
|
||||
{
|
||||
if (node.componentCount > 4)
|
||||
throw AstError{ "Cannot swizzle more than four elements" };
|
||||
|
||||
const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType();
|
||||
const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache);
|
||||
if (!IsBasicType(exprType))
|
||||
throw AstError{ "Cannot swizzle this type" };
|
||||
|
||||
switch (std::get<ShaderNodes::BasicType>(exprType))
|
||||
switch (std::get<BasicType>(exprType))
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case BasicType::Float1:
|
||||
case BasicType::Float2:
|
||||
case BasicType::Float3:
|
||||
case BasicType::Float4:
|
||||
case BasicType::Int1:
|
||||
case BasicType::Int2:
|
||||
case BasicType::Int3:
|
||||
case BasicType::Int4:
|
||||
break;
|
||||
|
||||
default:
|
||||
throw AstError{ "Cannot swizzle this type" };
|
||||
}
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::BuiltinVariable& var)
|
||||
void AstValidator::Visit(BranchStatement& node)
|
||||
{
|
||||
switch (var.entry)
|
||||
RegisterScope(node);
|
||||
|
||||
for (auto& condStatement : node.condStatements)
|
||||
{
|
||||
case ShaderNodes::BuiltinEntry::VertexPosition:
|
||||
if (!IsBasicType(var.type) ||
|
||||
std::get<ShaderNodes::BasicType>(var.type) != ShaderNodes::BasicType::Float4)
|
||||
throw AstError{ "Builtin is not of the expected type" };
|
||||
const ShaderExpressionType& condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache);
|
||||
if (!IsBasicType(condType) || std::get<BasicType>(condType) != BasicType::Boolean)
|
||||
throw AstError{ "if expression must resolve to boolean type" };
|
||||
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::InputVariable& var)
|
||||
{
|
||||
for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i)
|
||||
{
|
||||
const auto& input = m_shader.GetInput(i);
|
||||
if (input.name == var.name)
|
||||
{
|
||||
TypeMustMatch(input.type, var.type);
|
||||
return;
|
||||
}
|
||||
MandatoryStatement(condStatement.statement);
|
||||
}
|
||||
|
||||
throw AstError{ "Input not found" };
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::LocalVariable& var)
|
||||
void AstValidator::Visit(ConditionalStatement& node)
|
||||
{
|
||||
const auto& vars = m_context->declaredLocals;
|
||||
MandatoryStatement(node.statement);
|
||||
|
||||
auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& v) { return v.name == var.name; });
|
||||
if (it == vars.end())
|
||||
throw AstError{ "Local variable not found in this block" };
|
||||
RegisterScope(node);
|
||||
|
||||
TypeMustMatch(it->type, var.type);
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
|
||||
// throw AstError{ "condition not found" };
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::OutputVariable& var)
|
||||
void AstValidator::Visit(DeclareFunctionStatement& node)
|
||||
{
|
||||
for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i)
|
||||
auto& scope = EnterScope();
|
||||
|
||||
RegisterScope(node);
|
||||
|
||||
for (auto& parameter : node.parameters)
|
||||
{
|
||||
const auto& input = m_shader.GetOutput(i);
|
||||
if (input.name == var.name)
|
||||
{
|
||||
TypeMustMatch(input.type, var.type);
|
||||
return;
|
||||
}
|
||||
auto& identifier = scope.identifiers.emplace_back();
|
||||
identifier = AstCache::Identifier{ parameter.name, AstCache::Variable { parameter.type } };
|
||||
}
|
||||
|
||||
throw AstError{ "Output not found" };
|
||||
for (auto& statement : node.statements)
|
||||
MandatoryStatement(statement).Visit(*this);
|
||||
|
||||
ExitScope();
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::ParameterVariable& var)
|
||||
void AstValidator::Visit(DeclareStructStatement& node)
|
||||
{
|
||||
assert(m_context->currentFunction);
|
||||
assert(m_context);
|
||||
|
||||
const auto& parameters = m_context->currentFunction->parameters;
|
||||
if (!m_context->activeScopeId)
|
||||
throw AstError{ "cannot declare variable without scope" };
|
||||
|
||||
auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == var.name; });
|
||||
if (it == parameters.end())
|
||||
throw AstError{ "Parameter not found in function" };
|
||||
RegisterScope(node);
|
||||
|
||||
TypeMustMatch(it->type, var.type);
|
||||
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
|
||||
|
||||
auto& identifier = scope.identifiers.emplace_back();
|
||||
identifier = AstCache::Identifier{ node.description.name, node.description };
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderAstValidator::Visit(ShaderNodes::UniformVariable& var)
|
||||
void AstValidator::Visit(DeclareVariableStatement& node)
|
||||
{
|
||||
for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i)
|
||||
assert(m_context);
|
||||
|
||||
if (!m_context->activeScopeId)
|
||||
throw AstError{ "cannot declare variable without scope" };
|
||||
|
||||
RegisterScope(node);
|
||||
|
||||
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
|
||||
|
||||
auto& identifier = scope.identifiers.emplace_back();
|
||||
identifier = AstCache::Identifier{ node.varName, AstCache::Variable { node.varType } };
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(ExpressionStatement& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
MandatoryExpr(node.expression);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(MultiStatement& node)
|
||||
{
|
||||
assert(m_context);
|
||||
|
||||
EnterScope();
|
||||
|
||||
RegisterScope(node);
|
||||
|
||||
for (auto& statement : node.statements)
|
||||
MandatoryStatement(statement);
|
||||
|
||||
ExitScope();
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void AstValidator::Visit(ReturnStatement& node)
|
||||
{
|
||||
RegisterScope(node);
|
||||
|
||||
/*if (m_context->currentFunction->returnType != ShaderExpressionType(BasicType::Void))
|
||||
{
|
||||
const auto& uniform = m_shader.GetUniform(i);
|
||||
if (uniform.name == var.name)
|
||||
{
|
||||
TypeMustMatch(uniform.type, var.type);
|
||||
return;
|
||||
}
|
||||
if (GetExpressionType(MandatoryExpr(node.returnExpr)) != m_context->currentFunction->returnType)
|
||||
throw AstError{ "Return type doesn't match function return type" };
|
||||
}
|
||||
else
|
||||
{
|
||||
if (node.returnExpr)
|
||||
throw AstError{ "Unexpected expression for return (function doesn't return)" };
|
||||
}*/
|
||||
|
||||
throw AstError{ "Uniform not found" };
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
bool ValidateShader(const ShaderAst& shader, std::string* error)
|
||||
bool ValidateAst(StatementPtr& node, std::string* error, AstCache* cache)
|
||||
{
|
||||
ShaderAstValidator validator(shader);
|
||||
return validator.Validate(error);
|
||||
AstValidator validator;
|
||||
return validator.Validate(node, error, cache);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
|
||||
#include <stdexcept>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::AccessMember& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled AccessMember node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::AssignOp& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled AssignOp node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::BinaryOp& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled AccessMember node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::Branch& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled Branch node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::Cast& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled Cast node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalExpression& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled ConditionalExpression node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalStatement& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled ConditionalStatement node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::Constant& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled Constant node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::DeclareVariable& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled DeclareVariable node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::Discard& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled Discard node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::ExpressionStatement& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled ExpressionStatement node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::Identifier& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled Identifier node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::IntrinsicCall& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled IntrinsicCall node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::NoOp& node)
|
||||
{
|
||||
throw std::runtime_error("unhandled NoOp node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::ReturnStatement& node)
|
||||
{
|
||||
throw std::runtime_error("unhandled ReturnStatement node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::Sample2D& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled Sample2D node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::StatementBlock& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled StatementBlock node");
|
||||
}
|
||||
|
||||
void ShaderAstVisitorExcept::Visit(ShaderNodes::SwizzleOp& /*node*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled SwizzleOp node");
|
||||
}
|
||||
}
|
||||
@@ -42,6 +42,7 @@ namespace Nz::ShaderLang
|
||||
std::unordered_map<std::string, TokenType> reservedKeywords = {
|
||||
{ "false", TokenType::BoolFalse },
|
||||
{ "fn", TokenType::FunctionDeclaration },
|
||||
{ "let", TokenType::Let },
|
||||
{ "return", TokenType::Return },
|
||||
{ "true", TokenType::BoolTrue }
|
||||
};
|
||||
@@ -143,7 +144,7 @@ namespace Nz::ShaderLang
|
||||
while (next != -1);
|
||||
}
|
||||
else
|
||||
tokenType == TokenType::Divide;
|
||||
tokenType = TokenType::Divide;
|
||||
|
||||
break;
|
||||
}
|
||||
@@ -191,9 +192,11 @@ namespace Nz::ShaderLang
|
||||
|
||||
std::string valueStr(str.substr(start, currentPos - start + 1));
|
||||
|
||||
const char* ptr = valueStr.c_str();
|
||||
|
||||
char* end;
|
||||
double value = std::strtod(valueStr.c_str(), &end);
|
||||
if (end != &str[currentPos + 1])
|
||||
double value = std::strtod(ptr, &end);
|
||||
if (end != &ptr[valueStr.size()])
|
||||
throw BadNumber{};
|
||||
|
||||
token.data = value;
|
||||
@@ -218,6 +221,7 @@ namespace Nz::ShaderLang
|
||||
break;
|
||||
}
|
||||
|
||||
case '=': tokenType = TokenType::Assign; break;
|
||||
case '+': tokenType = TokenType::Plus; break;
|
||||
case '*': tokenType = TokenType::Multiply; break;
|
||||
case ':': tokenType = TokenType::Colon; break;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderLangParser.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <cassert>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
@@ -10,36 +11,38 @@ namespace Nz::ShaderLang
|
||||
{
|
||||
namespace
|
||||
{
|
||||
std::unordered_map<std::string, ShaderNodes::BasicType> identifierToBasicType = {
|
||||
{ "bool", ShaderNodes::BasicType::Boolean },
|
||||
std::unordered_map<std::string, ShaderAst::BasicType> identifierToBasicType = {
|
||||
{ "bool", ShaderAst::BasicType::Boolean },
|
||||
|
||||
{ "i32", ShaderNodes::BasicType::Int1 },
|
||||
{ "vec2i32", ShaderNodes::BasicType::Int2 },
|
||||
{ "vec3i32", ShaderNodes::BasicType::Int3 },
|
||||
{ "vec4i32", ShaderNodes::BasicType::Int4 },
|
||||
{ "i32", ShaderAst::BasicType::Int1 },
|
||||
{ "vec2i32", ShaderAst::BasicType::Int2 },
|
||||
{ "vec3i32", ShaderAst::BasicType::Int3 },
|
||||
{ "vec4i32", ShaderAst::BasicType::Int4 },
|
||||
|
||||
{ "f32", ShaderNodes::BasicType::Float1 },
|
||||
{ "vec2f32", ShaderNodes::BasicType::Float2 },
|
||||
{ "vec3f32", ShaderNodes::BasicType::Float3 },
|
||||
{ "vec4f32", ShaderNodes::BasicType::Float4 },
|
||||
{ "f32", ShaderAst::BasicType::Float1 },
|
||||
{ "vec2f32", ShaderAst::BasicType::Float2 },
|
||||
{ "vec3f32", ShaderAst::BasicType::Float3 },
|
||||
{ "vec4f32", ShaderAst::BasicType::Float4 },
|
||||
|
||||
{ "mat4x4f32", ShaderNodes::BasicType::Mat4x4 },
|
||||
{ "sampler2D", ShaderNodes::BasicType::Sampler2D },
|
||||
{ "void", ShaderNodes::BasicType::Void },
|
||||
{ "mat4x4f32", ShaderAst::BasicType::Mat4x4 },
|
||||
{ "sampler2D", ShaderAst::BasicType::Sampler2D },
|
||||
{ "void", ShaderAst::BasicType::Void },
|
||||
|
||||
{ "u32", ShaderNodes::BasicType::UInt1 },
|
||||
{ "vec2u32", ShaderNodes::BasicType::UInt3 },
|
||||
{ "vec3u32", ShaderNodes::BasicType::UInt3 },
|
||||
{ "vec4u32", ShaderNodes::BasicType::UInt4 },
|
||||
{ "u32", ShaderAst::BasicType::UInt1 },
|
||||
{ "vec2u32", ShaderAst::BasicType::UInt3 },
|
||||
{ "vec3u32", ShaderAst::BasicType::UInt3 },
|
||||
{ "vec4u32", ShaderAst::BasicType::UInt4 },
|
||||
};
|
||||
}
|
||||
|
||||
ShaderAst Parser::Parse(const std::vector<Token>& tokens)
|
||||
ShaderAst::StatementPtr Parser::Parse(const std::vector<Token>& tokens)
|
||||
{
|
||||
Context context;
|
||||
context.tokenCount = tokens.size();
|
||||
context.tokens = tokens.data();
|
||||
|
||||
context.root = std::make_unique<ShaderAst::MultiStatement>();
|
||||
|
||||
m_context = &context;
|
||||
|
||||
m_context->tokenIndex = -1;
|
||||
@@ -51,7 +54,7 @@ namespace Nz::ShaderLang
|
||||
switch (nextToken.type)
|
||||
{
|
||||
case TokenType::FunctionDeclaration:
|
||||
ParseFunctionDeclaration();
|
||||
context.root->statements.push_back(ParseFunctionDeclaration());
|
||||
break;
|
||||
|
||||
case TokenType::EndOfStream:
|
||||
@@ -63,7 +66,7 @@ namespace Nz::ShaderLang
|
||||
}
|
||||
}
|
||||
|
||||
return std::move(context.result);
|
||||
return std::move(context.root);
|
||||
}
|
||||
|
||||
const Token& Parser::Advance()
|
||||
@@ -92,12 +95,12 @@ namespace Nz::ShaderLang
|
||||
return m_context->tokens[m_context->tokenIndex + 1];
|
||||
}
|
||||
|
||||
ShaderNodes::StatementPtr Parser::ParseFunctionBody()
|
||||
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
|
||||
{
|
||||
return ParseStatementList();
|
||||
}
|
||||
|
||||
void Parser::ParseFunctionDeclaration()
|
||||
ShaderAst::StatementPtr Parser::ParseFunctionDeclaration()
|
||||
{
|
||||
ExpectNext(TokenType::FunctionDeclaration);
|
||||
|
||||
@@ -105,7 +108,7 @@ namespace Nz::ShaderLang
|
||||
|
||||
ExpectNext(TokenType::OpenParenthesis);
|
||||
|
||||
std::vector<ShaderAst::FunctionParameter> parameters;
|
||||
std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters;
|
||||
|
||||
bool firstParameter = true;
|
||||
for (;;)
|
||||
@@ -126,7 +129,7 @@ namespace Nz::ShaderLang
|
||||
|
||||
ExpectNext(TokenType::ClosingParenthesis);
|
||||
|
||||
ShaderExpressionType returnType = ShaderNodes::BasicType::Void;
|
||||
ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void;
|
||||
if (PeekNext().type == TokenType::FunctionReturn)
|
||||
{
|
||||
Advance(); //< Consume ->
|
||||
@@ -136,42 +139,46 @@ namespace Nz::ShaderLang
|
||||
|
||||
ExpectNext(TokenType::OpenCurlyBracket);
|
||||
|
||||
ShaderNodes::StatementPtr functionBody = ParseFunctionBody();
|
||||
std::vector<ShaderAst::StatementPtr> functionBody = ParseFunctionBody();
|
||||
|
||||
ExpectNext(TokenType::ClosingCurlyBracket);
|
||||
|
||||
m_context->result.AddFunction(functionName, functionBody, std::move(parameters), returnType);
|
||||
return ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
|
||||
}
|
||||
|
||||
ShaderAst::FunctionParameter Parser::ParseFunctionParameter()
|
||||
ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter()
|
||||
{
|
||||
std::string parameterName = ParseIdentifierAsName();
|
||||
|
||||
ExpectNext(TokenType::Colon);
|
||||
|
||||
ShaderExpressionType parameterType = ParseIdentifierAsType();
|
||||
ShaderAst::ShaderExpressionType parameterType = ParseIdentifierAsType();
|
||||
|
||||
return { parameterName, parameterType };
|
||||
}
|
||||
|
||||
ShaderNodes::StatementPtr Parser::ParseReturnStatement()
|
||||
ShaderAst::StatementPtr Parser::ParseReturnStatement()
|
||||
{
|
||||
ExpectNext(TokenType::Return);
|
||||
|
||||
ShaderNodes::ExpressionPtr expr;
|
||||
ShaderAst::ExpressionPtr expr;
|
||||
if (PeekNext().type != TokenType::Semicolon)
|
||||
expr = ParseExpression();
|
||||
|
||||
return ShaderNodes::ReturnStatement::Build(std::move(expr));
|
||||
return ShaderBuilder::Return(std::move(expr));
|
||||
}
|
||||
|
||||
ShaderNodes::StatementPtr Parser::ParseStatement()
|
||||
ShaderAst::StatementPtr Parser::ParseStatement()
|
||||
{
|
||||
const Token& token = PeekNext();
|
||||
|
||||
ShaderNodes::StatementPtr statement;
|
||||
ShaderAst::StatementPtr statement;
|
||||
switch (token.type)
|
||||
{
|
||||
case TokenType::Let:
|
||||
statement = ParseVariableDeclaration();
|
||||
break;
|
||||
|
||||
case TokenType::Return:
|
||||
statement = ParseReturnStatement();
|
||||
break;
|
||||
@@ -185,18 +192,38 @@ namespace Nz::ShaderLang
|
||||
return statement;
|
||||
}
|
||||
|
||||
ShaderNodes::StatementPtr Parser::ParseStatementList()
|
||||
std::vector<ShaderAst::StatementPtr> Parser::ParseStatementList()
|
||||
{
|
||||
std::vector<ShaderNodes::StatementPtr> statements;
|
||||
std::vector<ShaderAst::StatementPtr> statements;
|
||||
while (PeekNext().type != TokenType::ClosingCurlyBracket)
|
||||
{
|
||||
statements.push_back(ParseStatement());
|
||||
}
|
||||
|
||||
return ShaderNodes::StatementBlock::Build(std::move(statements));
|
||||
return statements;
|
||||
}
|
||||
|
||||
ShaderNodes::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs)
|
||||
ShaderAst::StatementPtr Parser::ParseVariableDeclaration()
|
||||
{
|
||||
ExpectNext(TokenType::Let);
|
||||
|
||||
std::string variableName = ParseIdentifierAsName();
|
||||
|
||||
ExpectNext(TokenType::Colon);
|
||||
|
||||
ShaderAst::ShaderExpressionType variableType = ParseIdentifierAsType();
|
||||
|
||||
ShaderAst::ExpressionPtr expression;
|
||||
if (PeekNext().type == TokenType::Assign)
|
||||
{
|
||||
Advance();
|
||||
expression = ParseExpression();
|
||||
}
|
||||
|
||||
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
|
||||
}
|
||||
|
||||
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)
|
||||
{
|
||||
for (;;)
|
||||
{
|
||||
@@ -207,7 +234,7 @@ namespace Nz::ShaderLang
|
||||
return lhs;
|
||||
|
||||
Advance();
|
||||
ShaderNodes::ExpressionPtr rhs = ParsePrimaryExpression();
|
||||
ShaderAst::ExpressionPtr rhs = ParsePrimaryExpression();
|
||||
|
||||
const Token& nextOp = PeekNext();
|
||||
|
||||
@@ -215,57 +242,58 @@ namespace Nz::ShaderLang
|
||||
if (tokenPrecedence < nextTokenPrecedence)
|
||||
rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs));
|
||||
|
||||
ShaderNodes::BinaryType binaryType;
|
||||
ShaderAst::BinaryType binaryType;
|
||||
{
|
||||
switch (currentOp.type)
|
||||
{
|
||||
case TokenType::Plus: binaryType = ShaderNodes::BinaryType::Add; break;
|
||||
case TokenType::Minus: binaryType = ShaderNodes::BinaryType::Subtract; break;
|
||||
case TokenType::Multiply: binaryType = ShaderNodes::BinaryType::Multiply; break;
|
||||
case TokenType::Divide: binaryType = ShaderNodes::BinaryType::Divide; break;
|
||||
case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break;
|
||||
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
|
||||
case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break;
|
||||
case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break;
|
||||
default: throw UnexpectedToken{};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
lhs = ShaderNodes::BinaryOp::Build(binaryType, std::move(lhs), std::move(rhs));
|
||||
lhs = ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs));
|
||||
}
|
||||
}
|
||||
|
||||
ShaderNodes::ExpressionPtr Parser::ParseExpression()
|
||||
ShaderAst::ExpressionPtr Parser::ParseExpression()
|
||||
{
|
||||
return ParseBinOpRhs(0, ParsePrimaryExpression());
|
||||
}
|
||||
|
||||
ShaderNodes::ExpressionPtr Parser::ParseIdentifier()
|
||||
ShaderAst::ExpressionPtr Parser::ParseIdentifier()
|
||||
{
|
||||
const Token& identifier = ExpectNext(TokenType::Identifier);
|
||||
|
||||
return ShaderNodes::Identifier::Build(ShaderNodes::ParameterVariable::Build(std::get<std::string>(identifier.data), ShaderNodes::BasicType::Float3));
|
||||
return ShaderBuilder::Identifier(std::get<std::string>(identifier.data));
|
||||
}
|
||||
|
||||
ShaderNodes::ExpressionPtr Parser::ParseIntegerExpression()
|
||||
ShaderAst::ExpressionPtr Parser::ParseIntegerExpression()
|
||||
{
|
||||
const Token& integer = ExpectNext(TokenType::IntegerValue);
|
||||
return ShaderNodes::Constant::Build(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
|
||||
return ShaderBuilder::Constant(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
|
||||
}
|
||||
|
||||
ShaderNodes::ExpressionPtr Parser::ParseParenthesisExpression()
|
||||
ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression()
|
||||
{
|
||||
ExpectNext(TokenType::OpenParenthesis);
|
||||
ShaderNodes::ExpressionPtr expression = ParseExpression();
|
||||
ShaderAst::ExpressionPtr expression = ParseExpression();
|
||||
ExpectNext(TokenType::ClosingParenthesis);
|
||||
|
||||
return expression;
|
||||
}
|
||||
|
||||
ShaderNodes::ExpressionPtr Parser::ParsePrimaryExpression()
|
||||
ShaderAst::ExpressionPtr Parser::ParsePrimaryExpression()
|
||||
{
|
||||
const Token& token = PeekNext();
|
||||
switch (token.type)
|
||||
{
|
||||
case TokenType::BoolFalse: return ShaderNodes::Constant::Build(false);
|
||||
case TokenType::BoolTrue: return ShaderNodes::Constant::Build(true);
|
||||
case TokenType::BoolFalse: return ShaderBuilder::Constant(false);
|
||||
case TokenType::BoolTrue: return ShaderBuilder::Constant(true);
|
||||
case TokenType::FloatingPointValue: return ShaderBuilder::Constant(float(std::get<double>(Advance().data))); //< FIXME
|
||||
case TokenType::Identifier: return ParseIdentifier();
|
||||
case TokenType::IntegerValue: return ParseIntegerExpression();
|
||||
case TokenType::OpenParenthesis: return ParseParenthesisExpression();
|
||||
@@ -286,7 +314,7 @@ namespace Nz::ShaderLang
|
||||
return identifier;
|
||||
}
|
||||
|
||||
ShaderExpressionType Parser::ParseIdentifierAsType()
|
||||
ShaderAst::ShaderExpressionType Parser::ParseIdentifierAsType()
|
||||
{
|
||||
const Token& identifier = ExpectNext(TokenType::Identifier);
|
||||
|
||||
|
||||
@@ -4,265 +4,29 @@
|
||||
|
||||
#include <Nazara/Shader/ShaderNodes.hpp>
|
||||
#include <Nazara/Core/Algorithm.hpp>
|
||||
#include <Nazara/Shader/ShaderAstSerializer.hpp>
|
||||
#include <Nazara/Shader/ShaderAstVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderWriter.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderNodes
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
Node::~Node() = default;
|
||||
|
||||
void ExpressionStatement::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
#define NAZARA_SHADERAST_NODE(Node) NodeType Node::GetType() const \
|
||||
{ \
|
||||
return NodeType:: Node; \
|
||||
}
|
||||
#include <Nazara/Shader/ShaderAstNodes.hpp>
|
||||
|
||||
#define NAZARA_SHADERAST_EXPRESSION(Node) void Node::Visit(AstExpressionVisitor& visitor) \
|
||||
{\
|
||||
visitor.Visit(*this); \
|
||||
}
|
||||
|
||||
|
||||
void ConditionalStatement::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
#define NAZARA_SHADERAST_STATEMENT(Node) void Node::Visit(AstStatementVisitor& visitor) \
|
||||
{\
|
||||
visitor.Visit(*this); \
|
||||
}
|
||||
|
||||
|
||||
void StatementBlock::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
void DeclareVariable::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
void Discard::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
ShaderExpressionType Identifier::GetExpressionType() const
|
||||
{
|
||||
assert(var);
|
||||
return var->type;
|
||||
}
|
||||
|
||||
void Identifier::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
ShaderExpressionType AccessMember::GetExpressionType() const
|
||||
{
|
||||
return exprType;
|
||||
}
|
||||
|
||||
void AccessMember::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
void NoOp::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
void ReturnStatement::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
ShaderExpressionType AssignOp::GetExpressionType() const
|
||||
{
|
||||
return left->GetExpressionType();
|
||||
}
|
||||
|
||||
void AssignOp::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
ShaderExpressionType BinaryOp::GetExpressionType() const
|
||||
{
|
||||
std::optional<ShaderExpressionType> exprType;
|
||||
|
||||
switch (op)
|
||||
{
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Subtract:
|
||||
exprType = left->GetExpressionType();
|
||||
break;
|
||||
|
||||
case BinaryType::Divide:
|
||||
case BinaryType::Multiply:
|
||||
{
|
||||
const ShaderExpressionType& leftExprType = left->GetExpressionType();
|
||||
assert(IsBasicType(leftExprType));
|
||||
|
||||
const ShaderExpressionType& rightExprType = right->GetExpressionType();
|
||||
assert(IsBasicType(rightExprType));
|
||||
|
||||
switch (std::get<BasicType>(leftExprType))
|
||||
{
|
||||
case BasicType::Boolean:
|
||||
case BasicType::Float2:
|
||||
case BasicType::Float3:
|
||||
case BasicType::Float4:
|
||||
case BasicType::Int2:
|
||||
case BasicType::Int3:
|
||||
case BasicType::Int4:
|
||||
case BasicType::UInt2:
|
||||
case BasicType::UInt3:
|
||||
case BasicType::UInt4:
|
||||
exprType = leftExprType;
|
||||
break;
|
||||
|
||||
case BasicType::Float1:
|
||||
case BasicType::Int1:
|
||||
case BasicType::Mat4x4:
|
||||
case BasicType::UInt1:
|
||||
exprType = rightExprType;
|
||||
break;
|
||||
|
||||
case BasicType::Sampler2D:
|
||||
case BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case BinaryType::CompEq:
|
||||
case BinaryType::CompGe:
|
||||
case BinaryType::CompGt:
|
||||
case BinaryType::CompLe:
|
||||
case BinaryType::CompLt:
|
||||
case BinaryType::CompNe:
|
||||
exprType = BasicType::Boolean;
|
||||
break;
|
||||
}
|
||||
|
||||
NazaraAssert(exprType.has_value(), "Unhandled builtin");
|
||||
|
||||
return *exprType;
|
||||
}
|
||||
|
||||
void BinaryOp::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
void Branch::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
ShaderExpressionType Constant::GetExpressionType() const
|
||||
{
|
||||
return std::visit([&](auto&& arg)
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
|
||||
if constexpr (std::is_same_v<T, bool>)
|
||||
return ShaderNodes::BasicType::Boolean;
|
||||
else if constexpr (std::is_same_v<T, float>)
|
||||
return ShaderNodes::BasicType::Float1;
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
return ShaderNodes::BasicType::Int1;
|
||||
else if constexpr (std::is_same_v<T, UInt32>)
|
||||
return ShaderNodes::BasicType::Int1;
|
||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||
return ShaderNodes::BasicType::Float2;
|
||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||
return ShaderNodes::BasicType::Float3;
|
||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||
return ShaderNodes::BasicType::Float4;
|
||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||
return ShaderNodes::BasicType::Int2;
|
||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||
return ShaderNodes::BasicType::Int3;
|
||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||
return ShaderNodes::BasicType::Int4;
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, value);
|
||||
}
|
||||
|
||||
void Constant::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
ShaderExpressionType Cast::GetExpressionType() const
|
||||
{
|
||||
return exprType;
|
||||
}
|
||||
|
||||
void Cast::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
ShaderExpressionType ConditionalExpression::GetExpressionType() const
|
||||
{
|
||||
assert(truePath->GetExpressionType() == falsePath->GetExpressionType());
|
||||
return truePath->GetExpressionType();
|
||||
}
|
||||
|
||||
void ConditionalExpression::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
ShaderExpressionType SwizzleOp::GetExpressionType() const
|
||||
{
|
||||
const ShaderExpressionType& exprType = expression->GetExpressionType();
|
||||
assert(IsBasicType(exprType));
|
||||
|
||||
return static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + componentCount - 1);
|
||||
}
|
||||
|
||||
void SwizzleOp::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
ShaderExpressionType Sample2D::GetExpressionType() const
|
||||
{
|
||||
return BasicType::Float4;
|
||||
}
|
||||
|
||||
void Sample2D::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
ShaderExpressionType IntrinsicCall::GetExpressionType() const
|
||||
{
|
||||
switch (intrinsic)
|
||||
{
|
||||
case IntrinsicType::CrossProduct:
|
||||
return parameters.front()->GetExpressionType();
|
||||
|
||||
case IntrinsicType::DotProduct:
|
||||
return BasicType::Float1;
|
||||
}
|
||||
|
||||
NazaraAssert(false, "Unhandled builtin");
|
||||
return BasicType::Void;
|
||||
}
|
||||
|
||||
void IntrinsicCall::Visit(ShaderAstVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
#include <Nazara/Shader/ShaderAstNodes.hpp>
|
||||
}
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
|
||||
#include <stdexcept>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
void ShaderVarVisitorExcept::Visit(ShaderNodes::BuiltinVariable& /*var*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled BuiltinVariable");
|
||||
}
|
||||
|
||||
void ShaderVarVisitorExcept::Visit(ShaderNodes::InputVariable& /*var*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled InputVariable");
|
||||
}
|
||||
|
||||
void ShaderVarVisitorExcept::Visit(ShaderNodes::LocalVariable& /*var*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled LocalVariable");
|
||||
}
|
||||
|
||||
void ShaderVarVisitorExcept::Visit(ShaderNodes::OutputVariable& /*var*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled OutputVariable");
|
||||
}
|
||||
|
||||
void ShaderVarVisitorExcept::Visit(ShaderNodes::ParameterVariable& /*var*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled ParameterVariable");
|
||||
}
|
||||
|
||||
void ShaderVarVisitorExcept::Visit(ShaderNodes::UniformVariable& /*var*/)
|
||||
{
|
||||
throw std::runtime_error("unhandled UniformVariable");
|
||||
}
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderVariables.hpp>
|
||||
#include <Nazara/Shader/ShaderVarVisitor.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderNodes
|
||||
{
|
||||
ShaderNodes::Variable::~Variable() = default;
|
||||
|
||||
VariableType BuiltinVariable::GetType() const
|
||||
{
|
||||
return VariableType::BuiltinVariable;
|
||||
}
|
||||
|
||||
void BuiltinVariable::Visit(ShaderVarVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
VariableType InputVariable::GetType() const
|
||||
{
|
||||
return VariableType::InputVariable;
|
||||
}
|
||||
|
||||
void InputVariable::Visit(ShaderVarVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
VariableType LocalVariable::GetType() const
|
||||
{
|
||||
return VariableType::LocalVariable;
|
||||
}
|
||||
|
||||
void LocalVariable::Visit(ShaderVarVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
VariableType OutputVariable::GetType() const
|
||||
{
|
||||
return VariableType::OutputVariable;
|
||||
}
|
||||
|
||||
void OutputVariable::Visit(ShaderVarVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
VariableType ParameterVariable::GetType() const
|
||||
{
|
||||
return VariableType::ParameterVariable;
|
||||
}
|
||||
|
||||
void ParameterVariable::Visit(ShaderVarVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
|
||||
VariableType UniformVariable::GetType() const
|
||||
{
|
||||
return VariableType::UniformVariable;
|
||||
}
|
||||
|
||||
void UniformVariable::Visit(ShaderVarVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||
#include <Nazara/Core/StackVector.hpp>
|
||||
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
|
||||
#include <Nazara/Shader/SpirvSection.hpp>
|
||||
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
|
||||
#include <Nazara/Shader/SpirvExpressionStore.hpp>
|
||||
@@ -12,21 +13,21 @@
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
UInt32 SpirvAstVisitor::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr)
|
||||
UInt32 SpirvAstVisitor::EvaluateExpression(ShaderAst::ExpressionPtr& expr)
|
||||
{
|
||||
Visit(expr);
|
||||
expr->Visit(*this);
|
||||
|
||||
assert(m_resultIds.size() == 1);
|
||||
return PopResultId();
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::AccessMember& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::AccessMemberExpression& node)
|
||||
{
|
||||
SpirvExpressionLoad accessMemberVisitor(m_writer, *m_currentBlock);
|
||||
PushResultId(accessMemberVisitor.Evaluate(node));
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::AssignOp& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::AssignExpression& node)
|
||||
{
|
||||
UInt32 resultId = EvaluateExpression(node.right);
|
||||
|
||||
@@ -36,20 +37,20 @@ namespace Nz
|
||||
PushResultId(resultId);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::BinaryOp& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
|
||||
{
|
||||
ShaderExpressionType resultExprType = node.GetExpressionType();
|
||||
ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node);
|
||||
assert(IsBasicType(resultExprType));
|
||||
|
||||
const ShaderExpressionType& leftExprType = node.left->GetExpressionType();
|
||||
ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left);
|
||||
assert(IsBasicType(leftExprType));
|
||||
|
||||
const ShaderExpressionType& rightExprType = node.right->GetExpressionType();
|
||||
ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right);
|
||||
assert(IsBasicType(rightExprType));
|
||||
|
||||
ShaderNodes::BasicType resultType = std::get<ShaderNodes::BasicType>(resultExprType);
|
||||
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
|
||||
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
|
||||
ShaderAst::BasicType resultType = std::get<ShaderAst::BasicType>(resultExprType);
|
||||
ShaderAst::BasicType leftType = std::get<ShaderAst::BasicType>(leftExprType);
|
||||
ShaderAst::BasicType rightType = std::get<ShaderAst::BasicType>(rightExprType);
|
||||
|
||||
|
||||
UInt32 leftOperand = EvaluateExpression(node.left);
|
||||
@@ -62,308 +63,308 @@ namespace Nz
|
||||
{
|
||||
switch (node.op)
|
||||
{
|
||||
case ShaderNodes::BinaryType::Add:
|
||||
case ShaderAst::BinaryType::Add:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpFAdd;
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpIAdd;
|
||||
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BinaryType::Subtract:
|
||||
case ShaderAst::BinaryType::Subtract:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpFSub;
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpISub;
|
||||
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BinaryType::Divide:
|
||||
case ShaderAst::BinaryType::Divide:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpFDiv;
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
return SpirvOp::OpSDiv;
|
||||
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpUDiv;
|
||||
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BinaryType::CompEq:
|
||||
case ShaderAst::BinaryType::CompEq:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
return SpirvOp::OpLogicalEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpFOrdEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpIEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BinaryType::CompGe:
|
||||
case ShaderAst::BinaryType::CompGe:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpFOrdGreaterThan;
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
return SpirvOp::OpSGreaterThan;
|
||||
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpUGreaterThan;
|
||||
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BinaryType::CompGt:
|
||||
case ShaderAst::BinaryType::CompGt:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpFOrdGreaterThanEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
return SpirvOp::OpSGreaterThanEqual;
|
||||
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpUGreaterThanEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BinaryType::CompLe:
|
||||
case ShaderAst::BinaryType::CompLe:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpFOrdLessThanEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
return SpirvOp::OpSLessThanEqual;
|
||||
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpULessThanEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BinaryType::CompLt:
|
||||
case ShaderAst::BinaryType::CompLt:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpFOrdLessThan;
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
return SpirvOp::OpSLessThan;
|
||||
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpULessThan;
|
||||
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BinaryType::CompNe:
|
||||
case ShaderAst::BinaryType::CompNe:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
return SpirvOp::OpLogicalNotEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpFOrdNotEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpINotEqual;
|
||||
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Void:
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BinaryType::Multiply:
|
||||
case ShaderAst::BinaryType::Multiply:
|
||||
{
|
||||
switch (leftType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
{
|
||||
switch (rightType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
return SpirvOp::OpFMul;
|
||||
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
swapOperands = true;
|
||||
return SpirvOp::OpVectorTimesScalar;
|
||||
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
swapOperands = true;
|
||||
return SpirvOp::OpMatrixTimesScalar;
|
||||
|
||||
@@ -374,21 +375,21 @@ namespace Nz
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
{
|
||||
switch (rightType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
return SpirvOp::OpVectorTimesScalar;
|
||||
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
return SpirvOp::OpFMul;
|
||||
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return SpirvOp::OpVectorTimesMatrix;
|
||||
|
||||
default:
|
||||
@@ -398,23 +399,23 @@ namespace Nz
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
return SpirvOp::OpIMul;
|
||||
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
{
|
||||
switch (rightType)
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
|
||||
case ShaderNodes::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
|
||||
case ShaderNodes::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
|
||||
case ShaderAst::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
|
||||
case ShaderAst::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
|
||||
case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
|
||||
|
||||
default:
|
||||
break;
|
||||
@@ -442,7 +443,7 @@ namespace Nz
|
||||
PushResultId(resultId);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::Branch& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::BranchStatement& node)
|
||||
{
|
||||
assert(!node.condStatements.empty());
|
||||
auto& firstCond = node.condStatements.front();
|
||||
@@ -450,7 +451,8 @@ namespace Nz
|
||||
UInt32 previousConditionId = EvaluateExpression(firstCond.condition);
|
||||
SpirvBlock previousContentBlock(m_writer);
|
||||
m_currentBlock = &previousContentBlock;
|
||||
Visit(firstCond.statement);
|
||||
|
||||
firstCond.statement->Visit(*this);
|
||||
|
||||
SpirvBlock mergeBlock(m_writer);
|
||||
m_blocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
|
||||
@@ -458,7 +460,7 @@ namespace Nz
|
||||
std::optional<std::size_t> nextBlock;
|
||||
for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex)
|
||||
{
|
||||
const auto& statement = node.condStatements[statementIndex];
|
||||
auto& statement = node.condStatements[statementIndex];
|
||||
|
||||
SpirvBlock contentBlock(m_writer);
|
||||
|
||||
@@ -469,7 +471,8 @@ namespace Nz
|
||||
previousContentBlock = std::move(contentBlock);
|
||||
|
||||
m_currentBlock = &previousContentBlock;
|
||||
Visit(statement.statement);
|
||||
|
||||
statement.statement->Visit(*this);
|
||||
}
|
||||
|
||||
if (node.elseStatement)
|
||||
@@ -477,7 +480,7 @@ namespace Nz
|
||||
SpirvBlock elseBlock(m_writer);
|
||||
|
||||
m_currentBlock = &elseBlock;
|
||||
Visit(node.elseStatement);
|
||||
node.elseStatement->Visit(*this);
|
||||
|
||||
elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice
|
||||
|
||||
@@ -496,16 +499,16 @@ namespace Nz
|
||||
m_currentBlock = &m_blocks.back();
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::Cast& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node)
|
||||
{
|
||||
const ShaderExpressionType& targetExprType = node.exprType;
|
||||
const ShaderAst::ShaderExpressionType& targetExprType = node.targetType;
|
||||
assert(IsBasicType(targetExprType));
|
||||
|
||||
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
|
||||
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
|
||||
|
||||
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
|
||||
|
||||
for (const auto& exprPtr : node.expressions)
|
||||
for (auto& exprPtr : node.expressions)
|
||||
{
|
||||
if (!exprPtr)
|
||||
break;
|
||||
@@ -527,21 +530,21 @@ namespace Nz
|
||||
PushResultId(resultId);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalExpression& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::ConditionalExpression& node)
|
||||
{
|
||||
if (m_writer.IsConditionEnabled(node.conditionName))
|
||||
Visit(node.truePath);
|
||||
node.truePath->Visit(*this);
|
||||
else
|
||||
Visit(node.falsePath);
|
||||
node.falsePath->Visit(*this);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalStatement& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::ConditionalStatement& node)
|
||||
{
|
||||
if (m_writer.IsConditionEnabled(node.conditionName))
|
||||
Visit(node.statement);
|
||||
node.statement->Visit(*this);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::Constant& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::ConstantExpression& node)
|
||||
{
|
||||
std::visit([&] (const auto& value)
|
||||
{
|
||||
@@ -549,46 +552,42 @@ namespace Nz
|
||||
}, node.value);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::DeclareVariable& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DeclareVariableStatement& node)
|
||||
{
|
||||
if (node.expression)
|
||||
{
|
||||
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
|
||||
|
||||
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
|
||||
m_writer.WriteLocalVariable(localVar.name, EvaluateExpression(node.expression));
|
||||
}
|
||||
if (node.initialExpression)
|
||||
m_writer.WriteLocalVariable(node.varName, EvaluateExpression(node.initialExpression));
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::Discard& /*node*/)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::DiscardStatement& /*node*/)
|
||||
{
|
||||
m_currentBlock->Append(SpirvOp::OpKill);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::ExpressionStatement& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::ExpressionStatement& node)
|
||||
{
|
||||
Visit(node.expression);
|
||||
node.expression->Visit(*this);
|
||||
|
||||
PopResultId();
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::Identifier& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::IdentifierExpression& node)
|
||||
{
|
||||
SpirvExpressionLoad loadVisitor(m_writer, *m_currentBlock);
|
||||
PushResultId(loadVisitor.Evaluate(node));
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::IntrinsicCall& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
|
||||
{
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
case ShaderNodes::IntrinsicType::DotProduct:
|
||||
case ShaderAst::IntrinsicType::DotProduct:
|
||||
{
|
||||
const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType();
|
||||
const ShaderAst::ShaderExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsBasicType(vecExprType));
|
||||
|
||||
ShaderNodes::BasicType vecType = std::get<ShaderNodes::BasicType>(vecExprType);
|
||||
ShaderAst::BasicType vecType = std::get<ShaderAst::BasicType>(vecExprType);
|
||||
|
||||
UInt32 typeId = m_writer.GetTypeId(node.GetComponentType(vecType));
|
||||
UInt32 typeId = m_writer.GetTypeId(ShaderAst::GetComponentType(vecType));
|
||||
|
||||
UInt32 vec1 = EvaluateExpression(node.parameters[0]);
|
||||
UInt32 vec2 = EvaluateExpression(node.parameters[1]);
|
||||
@@ -600,18 +599,18 @@ namespace Nz
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::IntrinsicType::CrossProduct:
|
||||
case ShaderAst::IntrinsicType::CrossProduct:
|
||||
default:
|
||||
throw std::runtime_error("not yet implemented");
|
||||
}
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::NoOp& /*node*/)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::NoOpStatement& /*node*/)
|
||||
{
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::ReturnStatement& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node)
|
||||
{
|
||||
if (node.returnExpr)
|
||||
m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr));
|
||||
@@ -619,30 +618,18 @@ namespace Nz
|
||||
m_currentBlock->Append(SpirvOp::OpReturn);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::Sample2D& node)
|
||||
{
|
||||
UInt32 typeId = m_writer.GetTypeId(ShaderNodes::BasicType::Float4);
|
||||
|
||||
UInt32 samplerId = EvaluateExpression(node.sampler);
|
||||
UInt32 coordinatesId = EvaluateExpression(node.coordinates);
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
|
||||
PushResultId(resultId);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::StatementBlock& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::MultiStatement& node)
|
||||
{
|
||||
for (auto& statement : node.statements)
|
||||
Visit(statement);
|
||||
statement->Visit(*this);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderNodes::SwizzleOp& node)
|
||||
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
const ShaderExpressionType& targetExprType = node.GetExpressionType();
|
||||
const ShaderAst::ShaderExpressionType& targetExprType = ShaderAst::GetExpressionType(node);
|
||||
assert(IsBasicType(targetExprType));
|
||||
|
||||
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
|
||||
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
|
||||
|
||||
UInt32 exprResultId = EvaluateExpression(node.expression);
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
@@ -666,7 +653,7 @@ namespace Nz
|
||||
// Extract a single component from the vector
|
||||
assert(node.componentCount == 1);
|
||||
|
||||
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) );
|
||||
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderAst::SwizzleComponent::First) );
|
||||
}
|
||||
|
||||
PushResultId(resultId);
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/SpirvConstantCache.hpp>
|
||||
#include <Nazara/Shader/ShaderAst.hpp>
|
||||
#include <Nazara/Shader/SpirvSection.hpp>
|
||||
#include <Nazara/Utility/FieldOffsets.hpp>
|
||||
#include <tsl/ordered_map.h>
|
||||
@@ -536,7 +535,7 @@ namespace Nz
|
||||
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i>)
|
||||
{
|
||||
return ConstantComposite{
|
||||
BuildType((std::is_same_v<T, Vector2f>) ? ShaderNodes::BasicType::Float2 : ShaderNodes::BasicType::Int2),
|
||||
BuildType((std::is_same_v<T, Vector2f>) ? ShaderAst::BasicType::Float2 : ShaderAst::BasicType::Int2),
|
||||
{
|
||||
BuildConstant(arg.x),
|
||||
BuildConstant(arg.y)
|
||||
@@ -546,7 +545,7 @@ namespace Nz
|
||||
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i>)
|
||||
{
|
||||
return ConstantComposite{
|
||||
BuildType((std::is_same_v<T, Vector3f>) ? ShaderNodes::BasicType::Float3 : ShaderNodes::BasicType::Int3),
|
||||
BuildType((std::is_same_v<T, Vector3f>) ? ShaderAst::BasicType::Float3 : ShaderAst::BasicType::Int3),
|
||||
{
|
||||
BuildConstant(arg.x),
|
||||
BuildConstant(arg.y),
|
||||
@@ -557,7 +556,7 @@ namespace Nz
|
||||
else if constexpr (std::is_same_v<T, Vector4f> || std::is_same_v<T, Vector4i>)
|
||||
{
|
||||
return ConstantComposite{
|
||||
BuildType((std::is_same_v<T, Vector4f>) ? ShaderNodes::BasicType::Float4 : ShaderNodes::BasicType::Int4),
|
||||
BuildType((std::is_same_v<T, Vector4f>) ? ShaderAst::BasicType::Float4 : ShaderAst::BasicType::Int4),
|
||||
{
|
||||
BuildConstant(arg.x),
|
||||
BuildConstant(arg.y),
|
||||
@@ -571,7 +570,7 @@ namespace Nz
|
||||
}, value));
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderNodes::BasicType& type, SpirvStorageClass storageClass) -> TypePtr
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>(SpirvConstantCache::Pointer{
|
||||
SpirvConstantCache::BuildType(type),
|
||||
@@ -579,55 +578,55 @@ namespace Nz
|
||||
});
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst& shader, const ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
|
||||
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>(SpirvConstantCache::Pointer{
|
||||
SpirvConstantCache::BuildType(shader, type),
|
||||
SpirvConstantCache::BuildType(type),
|
||||
storageClass
|
||||
});
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderNodes::BasicType& type) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::BasicType& type) -> TypePtr
|
||||
{
|
||||
return std::make_shared<Type>([&]() -> AnyType
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case ShaderNodes::BasicType::Boolean:
|
||||
case ShaderAst::BasicType::Boolean:
|
||||
return Bool{};
|
||||
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderAst::BasicType::Float1:
|
||||
return Float{ 32 };
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
case ShaderAst::BasicType::Int1:
|
||||
return Integer{ 32, true };
|
||||
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
case ShaderNodes::BasicType::Float4:
|
||||
case ShaderNodes::BasicType::Int2:
|
||||
case ShaderNodes::BasicType::Int3:
|
||||
case ShaderNodes::BasicType::Int4:
|
||||
case ShaderNodes::BasicType::UInt2:
|
||||
case ShaderNodes::BasicType::UInt3:
|
||||
case ShaderNodes::BasicType::UInt4:
|
||||
case ShaderAst::BasicType::Float2:
|
||||
case ShaderAst::BasicType::Float3:
|
||||
case ShaderAst::BasicType::Float4:
|
||||
case ShaderAst::BasicType::Int2:
|
||||
case ShaderAst::BasicType::Int3:
|
||||
case ShaderAst::BasicType::Int4:
|
||||
case ShaderAst::BasicType::UInt2:
|
||||
case ShaderAst::BasicType::UInt3:
|
||||
case ShaderAst::BasicType::UInt4:
|
||||
{
|
||||
auto vecType = BuildType(ShaderNodes::Node::GetComponentType(type));
|
||||
UInt32 componentCount = ShaderNodes::Node::GetComponentCount(type);
|
||||
auto vecType = BuildType(ShaderAst::GetComponentType(type));
|
||||
UInt32 componentCount = ShaderAst::GetComponentCount(type);
|
||||
|
||||
return Vector{ vecType, componentCount };
|
||||
}
|
||||
|
||||
case ShaderNodes::BasicType::Mat4x4:
|
||||
return Matrix{ BuildType(ShaderNodes::BasicType::Float4), 4u };
|
||||
case ShaderAst::BasicType::Mat4x4:
|
||||
return Matrix{ BuildType(ShaderAst::BasicType::Float4), 4u };
|
||||
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
case ShaderAst::BasicType::UInt1:
|
||||
return Integer{ 32, false };
|
||||
|
||||
case ShaderNodes::BasicType::Void:
|
||||
case ShaderAst::BasicType::Void:
|
||||
return Void{};
|
||||
|
||||
case ShaderNodes::BasicType::Sampler2D:
|
||||
case ShaderAst::BasicType::Sampler2D:
|
||||
{
|
||||
auto imageType = Image{
|
||||
{}, //< qualifier
|
||||
@@ -635,7 +634,7 @@ namespace Nz
|
||||
{}, //< sampled
|
||||
SpirvDim::Dim2D, //< dim
|
||||
SpirvImageFormat::Unknown, //< format
|
||||
BuildType(ShaderNodes::BasicType::Float1), //< sampledType
|
||||
BuildType(ShaderAst::BasicType::Float1), //< sampledType
|
||||
false, //< arrayed,
|
||||
false //< multisampled
|
||||
};
|
||||
@@ -648,16 +647,16 @@ namespace Nz
|
||||
}());
|
||||
}
|
||||
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst& shader, const ShaderExpressionType& type) -> TypePtr
|
||||
auto SpirvConstantCache::BuildType(const ShaderAst::ShaderExpressionType& type) -> TypePtr
|
||||
{
|
||||
return std::visit([&](auto&& arg) -> TypePtr
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
|
||||
if constexpr (std::is_same_v<T, ShaderAst::BasicType>)
|
||||
return BuildType(arg);
|
||||
else if constexpr (std::is_same_v<T, std::string>)
|
||||
{
|
||||
// Register struct members type
|
||||
/*// Register struct members type
|
||||
const auto& structs = shader.GetStructs();
|
||||
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
|
||||
if (it == structs.end())
|
||||
@@ -675,7 +674,8 @@ namespace Nz
|
||||
sMembers.type = BuildType(shader, member.type);
|
||||
}
|
||||
|
||||
return std::make_shared<Type>(std::move(sType));
|
||||
return std::make_shared<Type>(std::move(sType));*/
|
||||
return nullptr;
|
||||
}
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace Nz
|
||||
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
|
||||
}
|
||||
|
||||
UInt32 SpirvExpressionLoad::Evaluate(ShaderNodes::Expression& node)
|
||||
UInt32 SpirvExpressionLoad::Evaluate(ShaderAst::Expression& node)
|
||||
{
|
||||
node.Visit(*this);
|
||||
|
||||
@@ -41,7 +41,7 @@ namespace Nz
|
||||
}, m_value);
|
||||
}
|
||||
|
||||
void SpirvExpressionLoad::Visit(ShaderNodes::AccessMember& node)
|
||||
/*void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberExpression& node)
|
||||
{
|
||||
Visit(node.structExpr);
|
||||
|
||||
@@ -49,6 +49,8 @@ namespace Nz
|
||||
{
|
||||
[&](const Pointer& pointer)
|
||||
{
|
||||
ShaderAst::ShaderExpressionType exprType = GetExpressionType(node.structExpr);
|
||||
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
|
||||
UInt32 typeId = m_writer.GetTypeId(node.exprType);
|
||||
@@ -87,40 +89,15 @@ namespace Nz
|
||||
throw std::runtime_error("an internal error occurred");
|
||||
}
|
||||
}, m_value);
|
||||
}
|
||||
}*/
|
||||
|
||||
void SpirvExpressionLoad::Visit(ShaderNodes::Identifier& node)
|
||||
void SpirvExpressionLoad::Visit(ShaderAst::IdentifierExpression& node)
|
||||
{
|
||||
Visit(node.var);
|
||||
}
|
||||
|
||||
void SpirvExpressionLoad::Visit(ShaderNodes::InputVariable& var)
|
||||
{
|
||||
auto inputVar = m_writer.GetInputVariable(var.name);
|
||||
|
||||
if (auto resultIdOpt = m_writer.ReadVariable(inputVar, SpirvWriter::OnlyCache{}))
|
||||
m_value = Value{ *resultIdOpt };
|
||||
if (node.identifier == "d")
|
||||
m_value = Value{ m_writer.ReadLocalVariable(node.identifier) };
|
||||
else
|
||||
m_value = Pointer{ SpirvStorageClass::Input, inputVar.varId, inputVar.typeId };
|
||||
}
|
||||
m_value = Value{ m_writer.ReadParameterVariable(node.identifier) };
|
||||
|
||||
void SpirvExpressionLoad::Visit(ShaderNodes::LocalVariable& var)
|
||||
{
|
||||
m_value = Value{ m_writer.ReadLocalVariable(var.name) };
|
||||
}
|
||||
|
||||
void SpirvExpressionLoad::Visit(ShaderNodes::ParameterVariable& var)
|
||||
{
|
||||
m_value = Value{ m_writer.ReadParameterVariable(var.name) };
|
||||
}
|
||||
|
||||
void SpirvExpressionLoad::Visit(ShaderNodes::UniformVariable& var)
|
||||
{
|
||||
auto uniformVar = m_writer.GetUniformVariable(var.name);
|
||||
|
||||
if (auto resultIdOpt = m_writer.ReadVariable(uniformVar, SpirvWriter::OnlyCache{}))
|
||||
m_value = Value{ *resultIdOpt };
|
||||
else
|
||||
m_value = Pointer{ SpirvStorageClass::Uniform, uniformVar.varId, uniformVar.typeId };
|
||||
//Visit(node.var);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,9 +15,9 @@ namespace Nz
|
||||
template<class... Ts> overloaded(Ts...)->overloaded<Ts...>;
|
||||
}
|
||||
|
||||
void SpirvExpressionStore::Store(const ShaderNodes::ExpressionPtr& node, UInt32 resultId)
|
||||
void SpirvExpressionStore::Store(ShaderAst::ExpressionPtr& node, UInt32 resultId)
|
||||
{
|
||||
Visit(node);
|
||||
node->Visit(*this);
|
||||
|
||||
std::visit(overloaded
|
||||
{
|
||||
@@ -36,7 +36,7 @@ namespace Nz
|
||||
}, m_value);
|
||||
}
|
||||
|
||||
void SpirvExpressionStore::Visit(ShaderNodes::AccessMember& node)
|
||||
/*void SpirvExpressionStore::Visit(ShaderAst::AccessMemberExpression& node)
|
||||
{
|
||||
Visit(node.structExpr);
|
||||
|
||||
@@ -70,34 +70,15 @@ namespace Nz
|
||||
throw std::runtime_error("an internal error occurred");
|
||||
}
|
||||
}, m_value);
|
||||
}
|
||||
}*/
|
||||
|
||||
void SpirvExpressionStore::Visit(ShaderNodes::Identifier& node)
|
||||
void SpirvExpressionStore::Visit(ShaderAst::IdentifierExpression& node)
|
||||
{
|
||||
Visit(node.var);
|
||||
m_value = LocalVar{ node.identifier };
|
||||
}
|
||||
|
||||
void SpirvExpressionStore::Visit(ShaderNodes::SwizzleOp& node)
|
||||
void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
throw std::runtime_error("not yet implemented");
|
||||
}
|
||||
|
||||
void SpirvExpressionStore::Visit(ShaderNodes::BuiltinVariable& var)
|
||||
{
|
||||
const auto& outputVar = m_writer.GetBuiltinVariable(var.entry);
|
||||
|
||||
m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId };
|
||||
}
|
||||
|
||||
void SpirvExpressionStore::Visit(ShaderNodes::LocalVariable& var)
|
||||
{
|
||||
m_value = LocalVar{ var.name };
|
||||
}
|
||||
|
||||
void SpirvExpressionStore::Visit(ShaderNodes::OutputVariable& var)
|
||||
{
|
||||
const auto& outputVar = m_writer.GetOutputVariable(var.name);
|
||||
|
||||
m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,155 +26,131 @@ namespace Nz
|
||||
{
|
||||
namespace
|
||||
{
|
||||
class PreVisitor : public ShaderAstRecursiveVisitor, public ShaderVarVisitor
|
||||
class PreVisitor : public ShaderAst::AstRecursiveVisitor
|
||||
{
|
||||
public:
|
||||
using BuiltinContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::BuiltinVariable>>;
|
||||
using ExtInstList = std::unordered_set<std::string>;
|
||||
using LocalContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::LocalVariable>>;
|
||||
using ParameterContainer = std::unordered_set< std::shared_ptr<const ShaderNodes::ParameterVariable>>;
|
||||
using LocalContainer = std::unordered_set<ShaderAst::ShaderExpressionType>;
|
||||
|
||||
PreVisitor(const ShaderAst& shader, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
m_shader(shader),
|
||||
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
m_cache(cache),
|
||||
m_conditions(conditions),
|
||||
m_constantCache(constantCache)
|
||||
{
|
||||
}
|
||||
|
||||
using ShaderAstRecursiveVisitor::Visit;
|
||||
using ShaderVarVisitor::Visit;
|
||||
|
||||
void Visit(ShaderNodes::AccessMember& node) override
|
||||
void Visit(ShaderAst::AccessMemberExpression& node) override
|
||||
{
|
||||
for (std::size_t index : node.memberIndices)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));
|
||||
/*for (std::size_t index : node.memberIdentifiers)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalExpression& node) override
|
||||
void Visit(ShaderAst::ConditionalExpression& node) override
|
||||
{
|
||||
std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
|
||||
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != ShaderAst::InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
|
||||
Visit(node.truePath);
|
||||
else
|
||||
Visit(node.falsePath);
|
||||
Visit(node.falsePath);*/
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ConditionalStatement& node) override
|
||||
void Visit(ShaderAst::ConditionalStatement& node) override
|
||||
{
|
||||
std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
|
||||
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
|
||||
assert(conditionIndex != ShaderAst::InvalidCondition);
|
||||
|
||||
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
|
||||
Visit(node.statement);
|
||||
Visit(node.statement);*/
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Constant& node) override
|
||||
void Visit(ShaderAst::ConstantExpression& node) override
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
|
||||
}, node.value);
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::DeclareVariable& node) override
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
Visit(node.variable);
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.returnType));
|
||||
for (auto& parameter : node.parameters)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(parameter.type));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::Identifier& node) override
|
||||
void Visit(ShaderAst::DeclareStructStatement& node) override
|
||||
{
|
||||
Visit(node.var);
|
||||
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
for (auto& field : node.description.members)
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(field.type));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::IntrinsicCall& node) override
|
||||
void Visit(ShaderAst::DeclareVariableStatement& node) override
|
||||
{
|
||||
ShaderAstRecursiveVisitor::Visit(node);
|
||||
variableTypes.insert(node.varType);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IdentifierExpression& node) override
|
||||
{
|
||||
variableTypes.insert(GetExpressionType(node, m_cache));
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IntrinsicExpression& node) override
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
// Require GLSL.std.450
|
||||
case ShaderNodes::IntrinsicType::CrossProduct:
|
||||
case ShaderAst::IntrinsicType::CrossProduct:
|
||||
extInsts.emplace("GLSL.std.450");
|
||||
break;
|
||||
|
||||
// Part of SPIR-V core
|
||||
case ShaderNodes::IntrinsicType::DotProduct:
|
||||
case ShaderAst::IntrinsicType::DotProduct:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::BuiltinVariable& var) override
|
||||
{
|
||||
builtinVars.insert(std::static_pointer_cast<const ShaderNodes::BuiltinVariable>(var.shared_from_this()));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::InputVariable& /*var*/) override
|
||||
{
|
||||
/* Handled by ShaderAst */
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::LocalVariable& var) override
|
||||
{
|
||||
localVars.insert(std::static_pointer_cast<const ShaderNodes::LocalVariable>(var.shared_from_this()));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::OutputVariable& /*var*/) override
|
||||
{
|
||||
/* Handled by ShaderAst */
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::ParameterVariable& var) override
|
||||
{
|
||||
paramVars.insert(std::static_pointer_cast<const ShaderNodes::ParameterVariable>(var.shared_from_this()));
|
||||
}
|
||||
|
||||
void Visit(ShaderNodes::UniformVariable& /*var*/) override
|
||||
{
|
||||
/* Handled by ShaderAst */
|
||||
}
|
||||
|
||||
BuiltinContainer builtinVars;
|
||||
ExtInstList extInsts;
|
||||
LocalContainer localVars;
|
||||
ParameterContainer paramVars;
|
||||
LocalContainer variableTypes;
|
||||
|
||||
private:
|
||||
const ShaderAst& m_shader;
|
||||
ShaderAst::AstCache* m_cache;
|
||||
const SpirvWriter::States& m_conditions;
|
||||
SpirvConstantCache& m_constantCache;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
constexpr ShaderNodes::BasicType GetBasicType()
|
||||
constexpr ShaderAst::BasicType GetBasicType()
|
||||
{
|
||||
if constexpr (std::is_same_v<T, bool>)
|
||||
return ShaderNodes::BasicType::Boolean;
|
||||
return ShaderAst::BasicType::Boolean;
|
||||
else if constexpr (std::is_same_v<T, float>)
|
||||
return(ShaderNodes::BasicType::Float1);
|
||||
return(ShaderAst::BasicType::Float1);
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
return(ShaderNodes::BasicType::Int1);
|
||||
return(ShaderAst::BasicType::Int1);
|
||||
else if constexpr (std::is_same_v<T, Vector2f>)
|
||||
return(ShaderNodes::BasicType::Float2);
|
||||
return(ShaderAst::BasicType::Float2);
|
||||
else if constexpr (std::is_same_v<T, Vector3f>)
|
||||
return(ShaderNodes::BasicType::Float3);
|
||||
return(ShaderAst::BasicType::Float3);
|
||||
else if constexpr (std::is_same_v<T, Vector4f>)
|
||||
return(ShaderNodes::BasicType::Float4);
|
||||
return(ShaderAst::BasicType::Float4);
|
||||
else if constexpr (std::is_same_v<T, Vector2i32>)
|
||||
return(ShaderNodes::BasicType::Int2);
|
||||
return(ShaderAst::BasicType::Int2);
|
||||
else if constexpr (std::is_same_v<T, Vector3i32>)
|
||||
return(ShaderNodes::BasicType::Int3);
|
||||
return(ShaderAst::BasicType::Int3);
|
||||
else if constexpr (std::is_same_v<T, Vector4i32>)
|
||||
return(ShaderNodes::BasicType::Int4);
|
||||
return(ShaderAst::BasicType::Int4);
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "unhandled type");
|
||||
}
|
||||
@@ -198,7 +174,7 @@ namespace Nz
|
||||
tsl::ordered_map<std::string, ExtVar> parameterIds;
|
||||
tsl::ordered_map<std::string, ExtVar> uniformIds;
|
||||
std::unordered_map<std::string, UInt32> extensionInstructions;
|
||||
std::unordered_map<ShaderNodes::BuiltinEntry, ExtVar> builtinIds;
|
||||
std::unordered_map<ShaderAst::BuiltinEntry, ExtVar> builtinIds;
|
||||
std::unordered_map<std::string, UInt32> varToResult;
|
||||
std::vector<Func> funcs;
|
||||
std::vector<SpirvBlock> functionBlocks;
|
||||
@@ -219,13 +195,12 @@ namespace Nz
|
||||
{
|
||||
}
|
||||
|
||||
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst& shader, const States& conditions)
|
||||
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
|
||||
{
|
||||
std::string error;
|
||||
if (!ValidateShader(shader, &error))
|
||||
if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
m_context.shader = &shader;
|
||||
m_context.states = &conditions;
|
||||
|
||||
State state;
|
||||
@@ -235,23 +210,19 @@ namespace Nz
|
||||
m_currentState = nullptr;
|
||||
});
|
||||
|
||||
std::vector<ShaderNodes::StatementPtr> functionStatements;
|
||||
std::vector<ShaderAst::StatementPtr> functionStatements;
|
||||
|
||||
ShaderAstCloner cloner;
|
||||
|
||||
PreVisitor preVisitor(shader, conditions, state.constantTypeCache);
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
functionStatements.emplace_back(cloner.Clone(func.statement));
|
||||
preVisitor.Visit(func.statement);
|
||||
}
|
||||
ShaderAst::AstCloner cloner;
|
||||
|
||||
// Register all extended instruction sets
|
||||
PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache);
|
||||
shader->Visit(preVisitor);
|
||||
|
||||
for (const std::string& extInst : preVisitor.extInsts)
|
||||
state.extensionInstructions[extInst] = AllocateResultId();
|
||||
|
||||
// Register all types
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
/*for (const auto& func : shader.GetFunctions())
|
||||
{
|
||||
RegisterType(func.returnType);
|
||||
for (const auto& param : func.parameters)
|
||||
@@ -270,8 +241,8 @@ namespace Nz
|
||||
for (const auto& func : shader.GetFunctions())
|
||||
RegisterFunctionType(func.returnType, func.parameters);
|
||||
|
||||
for (const auto& local : preVisitor.localVars)
|
||||
RegisterType(local->type);
|
||||
for (const auto& type : preVisitor.variableTypes)
|
||||
RegisterType(type);
|
||||
|
||||
for (const auto& builtin : preVisitor.builtinVars)
|
||||
RegisterType(builtin->type);
|
||||
@@ -283,7 +254,7 @@ namespace Nz
|
||||
SpirvBuiltIn builtinDecoration;
|
||||
switch (builtin->entry)
|
||||
{
|
||||
case ShaderNodes::BuiltinEntry::VertexPosition:
|
||||
case ShaderAst::BuiltinEntry::VertexPosition:
|
||||
variable.debugName = "builtin_VertexPosition";
|
||||
variable.storageClass = SpirvStorageClass::Output;
|
||||
|
||||
@@ -294,10 +265,10 @@ namespace Nz
|
||||
throw std::runtime_error("unexpected builtin type");
|
||||
}
|
||||
|
||||
const ShaderExpressionType& builtinExprType = builtin->type;
|
||||
const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type;
|
||||
assert(IsBasicType(builtinExprType));
|
||||
|
||||
ShaderNodes::BasicType builtinType = std::get<ShaderNodes::BasicType>(builtinExprType);
|
||||
ShaderAst::BasicType builtinType = std::get<ShaderAst::BasicType>(builtinExprType);
|
||||
|
||||
variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);
|
||||
|
||||
@@ -420,7 +391,7 @@ namespace Nz
|
||||
|
||||
if (!state.functionBlocks.back().IsTerminated())
|
||||
{
|
||||
assert(func.returnType == ShaderExpressionType(ShaderNodes::BasicType::Void));
|
||||
assert(func.returnType == ShaderAst::ShaderExpressionType(ShaderAst::BasicType::Void));
|
||||
state.functionBlocks.back().Append(SpirvOp::OpReturn);
|
||||
}
|
||||
|
||||
@@ -475,14 +446,14 @@ namespace Nz
|
||||
|
||||
if (m_context.shader->GetStage() == ShaderStageType::Fragment)
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft);
|
||||
}
|
||||
}*/
|
||||
|
||||
std::vector<UInt32> ret;
|
||||
MergeSections(ret, state.header);
|
||||
/*MergeSections(ret, state.header);
|
||||
MergeSections(ret, state.debugInfo);
|
||||
MergeSections(ret, state.annotations);
|
||||
MergeSections(ret, state.constants);
|
||||
MergeSections(ret, state.instructions);
|
||||
MergeSections(ret, state.instructions);*/
|
||||
|
||||
return ret;
|
||||
}
|
||||
@@ -516,16 +487,16 @@ namespace Nz
|
||||
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
|
||||
}
|
||||
|
||||
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
|
||||
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
|
||||
{
|
||||
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
|
||||
parameterTypes.reserve(parameters.size());
|
||||
|
||||
for (const auto& parameter : parameters)
|
||||
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(*m_context.shader, parameter.type, SpirvStorageClass::Function));
|
||||
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(parameter.type, SpirvStorageClass::Function));
|
||||
|
||||
return SpirvConstantCache::Function{
|
||||
SpirvConstantCache::BuildType(*m_context.shader, retType),
|
||||
SpirvConstantCache::BuildType(retType),
|
||||
std::move(parameterTypes)
|
||||
};
|
||||
}
|
||||
@@ -535,12 +506,12 @@ namespace Nz
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
|
||||
UInt32 SpirvWriter::GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) });
|
||||
}
|
||||
|
||||
auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar&
|
||||
auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
|
||||
{
|
||||
auto it = m_currentState->builtinIds.find(builtin);
|
||||
assert(it != m_currentState->builtinIds.end());
|
||||
@@ -572,14 +543,14 @@ namespace Nz
|
||||
return it.value();
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const
|
||||
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::GetTypeId(const ShaderExpressionType& type) const
|
||||
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ShaderExpressionType& type) const
|
||||
{
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(*m_context.shader, type));
|
||||
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
|
||||
@@ -673,20 +644,20 @@ namespace Nz
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
|
||||
UInt32 SpirvWriter::RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) });
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass)
|
||||
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass)
|
||||
{
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
|
||||
}
|
||||
|
||||
UInt32 SpirvWriter::RegisterType(ShaderExpressionType type)
|
||||
UInt32 SpirvWriter::RegisterType(ShaderAst::ShaderExpressionType type)
|
||||
{
|
||||
assert(m_currentState);
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(*m_context.shader, type));
|
||||
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
|
||||
}
|
||||
|
||||
void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)
|
||||
|
||||
Reference in New Issue
Block a user