Rework shader AST (WIP)

This commit is contained in:
Jérôme Leclercq
2021-03-10 11:18:13 +01:00
parent b320b5b44e
commit fed7370e77
73 changed files with 2721 additions and 4312 deletions

View File

@@ -167,8 +167,8 @@ namespace Nz
auto& fragmentShader = settings.shaders[UnderlyingCast(ShaderStageType::Fragment)];
auto& vertexShader = settings.shaders[UnderlyingCast(ShaderStageType::Vertex)];
fragmentShader = std::make_shared<UberShader>(UnserializeShader(r_fragmentShader, sizeof(r_fragmentShader)));
vertexShader = std::make_shared<UberShader>(UnserializeShader(r_vertexShader, sizeof(r_vertexShader)));
fragmentShader = std::make_shared<UberShader>(ShaderAst::UnserializeShader(r_fragmentShader, sizeof(r_fragmentShader)));
vertexShader = std::make_shared<UberShader>(ShaderAst::UnserializeShader(r_vertexShader, sizeof(r_vertexShader)));
// Conditions

View File

@@ -5,17 +5,17 @@
#include <Nazara/Graphics/UberShader.hpp>
#include <Nazara/Graphics/Graphics.hpp>
#include <Nazara/Renderer/RenderDevice.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <limits>
#include <stdexcept>
#include <Nazara/Graphics/Debug.hpp>
namespace Nz
{
UberShader::UberShader(ShaderAst shaderAst) :
UberShader::UberShader(ShaderAst::StatementPtr shaderAst) :
m_shaderAst(std::move(shaderAst))
{
std::size_t conditionCount = m_shaderAst.GetConditionCount();
//std::size_t conditionCount = m_shaderAst.GetConditionCount();
std::size_t conditionCount = 0;
if (conditionCount >= 64)
throw std::runtime_error("Too many conditions");
@@ -27,10 +27,10 @@ namespace Nz
UInt64 UberShader::GetConditionFlagByName(const std::string_view& condition) const
{
std::size_t conditionIndex = m_shaderAst.FindConditionByName(condition);
/*std::size_t conditionIndex = m_shaderAst.FindConditionByName(condition);
if (conditionIndex != ShaderAst::InvalidCondition)
return SetBit<UInt64>(0, conditionIndex);
else
else*/
return 0;
}

View File

@@ -8,6 +8,7 @@
#include <Nazara/Math/Algorithm.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderAstCloner.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <optional>
@@ -20,64 +21,67 @@ namespace Nz
{
static const char* flipYUniformName = "_NzFlipValue";
struct AstAdapter : ShaderAstCloner
struct AstAdapter : ShaderAst::AstCloner
{
void Visit(ShaderNodes::AssignOp& node) override
void Visit(ShaderAst::AssignExpression& node) override
{
if (!flipYPosition)
return AstCloner::Visit(node);
if (node.left->GetType() != ShaderAst::NodeType::IdentifierExpression)
return AstCloner::Visit(node);
/*
FIXME:
const auto& identifier = static_cast<const ShaderAst::Identifier&>(*node.left);
if (identifier.var->GetType() != ShaderAst::VariableType::BuiltinVariable)
return ShaderAstCloner::Visit(node);
if (node.left->GetType() != ShaderNodes::NodeType::Identifier)
const auto& builtinVar = static_cast<const ShaderAst::BuiltinVariable&>(*identifier.var);
if (builtinVar.entry != ShaderAst::BuiltinEntry::VertexPosition)
return ShaderAstCloner::Visit(node);
const auto& identifier = static_cast<const ShaderNodes::Identifier&>(*node.left);
if (identifier.var->GetType() != ShaderNodes::VariableType::BuiltinVariable)
return ShaderAstCloner::Visit(node);
const auto& builtinVar = static_cast<const ShaderNodes::BuiltinVariable&>(*identifier.var);
if (builtinVar.entry != ShaderNodes::BuiltinEntry::VertexPosition)
return ShaderAstCloner::Visit(node);
auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderNodes::BasicType::Float1);
auto flipVar = ShaderBuilder::Uniform(flipYUniformName, ShaderAst::BasicType::Float1);
auto oneConstant = ShaderBuilder::Constant(1.f);
auto fixYValue = ShaderBuilder::Cast<ShaderNodes::BasicType::Float4>(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant);
auto fixYValue = ShaderBuilder::Cast<ShaderAst::BasicType::Float4>(oneConstant, ShaderBuilder::Identifier(flipVar), oneConstant, oneConstant);
auto mulFix = ShaderBuilder::Multiply(CloneExpression(node.right), fixYValue);
PushExpression(ShaderNodes::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));
PushExpression(ShaderAst::AssignOp::Build(node.op, CloneExpression(node.left), mulFix));*/
}
bool flipYPosition = false;
};
}
struct GlslWriter::State
{
const States* states = nullptr;
ShaderAst::AstCache cache;
std::stringstream stream;
unsigned int indentLevel = 0;
};
GlslWriter::GlslWriter() :
m_currentState(nullptr)
{
}
std::string GlslWriter::Generate(const ShaderAst& inputShader, const States& conditions)
std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
{
const ShaderAst* selectedShader = &inputShader;
/*const ShaderAst* selectedShader = &inputShader;
std::optional<ShaderAst> modifiedShader;
if (inputShader.GetStage() == ShaderStageType::Vertex && m_environment.flipYPosition)
{
modifiedShader.emplace(inputShader);
modifiedShader->AddUniform(flipYUniformName, ShaderNodes::BasicType::Float1);
modifiedShader->AddUniform(flipYUniformName, ShaderAst::BasicType::Float1);
selectedShader = &modifiedShader.value();
}
const ShaderAst& shader = *selectedShader;
std::string error;
if (!ValidateShader(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
m_context.states = &conditions;
m_context.shader = &shader;
}*/
State state;
m_currentState = &state;
CallOnExit onExit([this]()
@@ -85,6 +89,10 @@ namespace Nz
m_currentState = nullptr;
});
std::string error;
if (!ShaderAst::ValidateAst(shader, &error, &state.cache))
throw std::runtime_error("Invalid shader AST: " + error);
unsigned int glslVersion;
if (m_environment.glES)
{
@@ -165,52 +173,7 @@ namespace Nz
AppendLine();
}
// Structures
/*if (shader.GetStructCount() > 0)
{
AppendCommentSection("Structures");
for (const auto& s : shader.GetStructs())
{
Append("struct ");
AppendLine(s.name);
AppendLine("{");
for (const auto& m : s.members)
{
Append("\t");
Append(m.type);
Append(" ");
Append(m.name);
AppendLine(";");
}
AppendLine("};");
AppendLine();
}
}*/
// Global variables (uniforms, input and outputs)
const char* inKeyword = (glslVersion >= 130) ? "in" : "varying";
const char* outKeyword = (glslVersion >= 130) ? "out" : "varying";
DeclareVariables(shader, shader.GetUniforms(), "uniform", "Uniforms");
DeclareVariables(shader, shader.GetInputs(), inKeyword, "Inputs");
DeclareVariables(shader, shader.GetOutputs(), outKeyword, "Outputs");
std::size_t functionCount = shader.GetFunctionCount();
if (functionCount > 1)
{
AppendCommentSection("Prototypes");
for (const auto& func : shader.GetFunctions())
{
if (func.name != "main")
{
AppendFunctionPrototype(func);
AppendLine(";");
}
}
}
for (const auto& func : shader.GetFunctions())
AppendFunction(func);
shader->Visit(*this);
return state.stream.str();
}
@@ -225,7 +188,7 @@ namespace Nz
return flipYUniformName;
}
void GlslWriter::Append(ShaderExpressionType type)
void GlslWriter::Append(ShaderAst::ShaderExpressionType type)
{
std::visit([&](auto&& arg)
{
@@ -233,49 +196,57 @@ namespace Nz
}, type);
}
void GlslWriter::Append(ShaderNodes::BuiltinEntry builtin)
void GlslWriter::Append(ShaderAst::BuiltinEntry builtin)
{
switch (builtin)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
case ShaderAst::BuiltinEntry::VertexPosition:
Append("gl_Position");
break;
}
}
void GlslWriter::Append(ShaderNodes::BasicType type)
void GlslWriter::Append(ShaderAst::BasicType type)
{
switch (type)
{
case ShaderNodes::BasicType::Boolean: return Append("bool");
case ShaderNodes::BasicType::Float1: return Append("float");
case ShaderNodes::BasicType::Float2: return Append("vec2");
case ShaderNodes::BasicType::Float3: return Append("vec3");
case ShaderNodes::BasicType::Float4: return Append("vec4");
case ShaderNodes::BasicType::Int1: return Append("int");
case ShaderNodes::BasicType::Int2: return Append("ivec2");
case ShaderNodes::BasicType::Int3: return Append("ivec3");
case ShaderNodes::BasicType::Int4: return Append("ivec4");
case ShaderNodes::BasicType::Mat4x4: return Append("mat4");
case ShaderNodes::BasicType::Sampler2D: return Append("sampler2D");
case ShaderNodes::BasicType::UInt1: return Append("uint");
case ShaderNodes::BasicType::UInt2: return Append("uvec2");
case ShaderNodes::BasicType::UInt3: return Append("uvec3");
case ShaderNodes::BasicType::UInt4: return Append("uvec4");
case ShaderNodes::BasicType::Void: return Append("void");
case ShaderAst::BasicType::Boolean: return Append("bool");
case ShaderAst::BasicType::Float1: return Append("float");
case ShaderAst::BasicType::Float2: return Append("vec2");
case ShaderAst::BasicType::Float3: return Append("vec3");
case ShaderAst::BasicType::Float4: return Append("vec4");
case ShaderAst::BasicType::Int1: return Append("int");
case ShaderAst::BasicType::Int2: return Append("ivec2");
case ShaderAst::BasicType::Int3: return Append("ivec3");
case ShaderAst::BasicType::Int4: return Append("ivec4");
case ShaderAst::BasicType::Mat4x4: return Append("mat4");
case ShaderAst::BasicType::Sampler2D: return Append("sampler2D");
case ShaderAst::BasicType::UInt1: return Append("uint");
case ShaderAst::BasicType::UInt2: return Append("uvec2");
case ShaderAst::BasicType::UInt3: return Append("uvec3");
case ShaderAst::BasicType::UInt4: return Append("uvec4");
case ShaderAst::BasicType::Void: return Append("void");
}
}
void GlslWriter::Append(ShaderNodes::MemoryLayout layout)
void GlslWriter::Append(ShaderAst::MemoryLayout layout)
{
switch (layout)
{
case ShaderNodes::MemoryLayout::Std140:
case ShaderAst::MemoryLayout::Std140:
Append("std140");
break;
}
}
template<typename T>
void GlslWriter::Append(const T& param)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
m_currentState->stream << param;
}
void GlslWriter::AppendCommentSection(const std::string& section)
{
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
@@ -285,67 +256,24 @@ namespace Nz
AppendLine();
}
void GlslWriter::AppendField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers)
void GlslWriter::AppendField(std::size_t scopeId, const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
const auto& structs = m_context.shader->GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
assert(it != structs.end());
const ShaderAst::Struct& s = *it;
assert(*memberIndex < s.members.size());
const auto& member = s.members[*memberIndex];
Append(".");
Append(member.name);
Append(memberIdentifier[0]);
const ShaderAst::AstCache::Identifier* identifier = m_currentState->cache.FindIdentifier(scopeId, structName);
assert(identifier);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
const auto& s = std::get<ShaderAst::StructDescription>(identifier->value);
auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
assert(memberIt != s.members.end());
const auto& member = *memberIt;
if (remainingMembers > 1)
{
assert(IsStructType(member.type));
AppendField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
}
}
void GlslWriter::AppendFunction(const ShaderAst::Function& func)
{
NazaraAssert(!m_context.currentFunction, "A function is already being processed");
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
AppendFunctionPrototype(func);
m_context.currentFunction = &func;
CallOnExit onExit([this] ()
{
m_context.currentFunction = nullptr;
});
EnterScope();
{
AstAdapter adapter;
adapter.flipYPosition = m_environment.flipYPosition;
Visit(adapter.Clone(func.statement));
}
LeaveScope();
}
void GlslWriter::AppendFunctionPrototype(const ShaderAst::Function& func)
{
Append(func.returnType);
Append(" ");
Append(func.name);
Append("(");
for (std::size_t i = 0; i < func.parameters.size(); ++i)
{
if (i != 0)
Append(", ");
Append(func.parameters[i].type);
Append(" ");
Append(func.parameters[i].name);
}
Append(")\n");
AppendField(scopeId, std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
}
void GlslWriter::AppendLine(const std::string& txt)
@@ -372,44 +300,46 @@ namespace Nz
AppendLine("}");
}
void GlslWriter::Visit(ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired)
void GlslWriter::Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired)
{
bool enclose = encloseIfRequired && (GetExpressionCategory(expr) != ShaderNodes::ExpressionCategory::LValue);
bool enclose = encloseIfRequired && (GetExpressionCategory(*expr) != ShaderAst::ExpressionCategory::LValue);
if (enclose)
Append("(");
ShaderAstVisitor::Visit(expr);
expr->Visit(*this);
if (enclose)
Append(")");
}
void GlslWriter::Visit(ShaderNodes::AccessMember& node)
void GlslWriter::Visit(ShaderAst::AccessMemberExpression& node)
{
Visit(node.structExpr, true);
const ShaderExpressionType& exprType = node.structExpr->GetExpressionType();
const ShaderAst::ShaderExpressionType& exprType = GetExpressionType(*node.structExpr, &m_currentState->cache);
assert(IsStructType(exprType));
AppendField(std::get<std::string>(exprType), node.memberIndices.data(), node.memberIndices.size());
std::size_t scopeId = m_currentState->cache.GetScopeId(&node);
AppendField(scopeId, std::get<std::string>(exprType), node.memberIdentifiers.data(), node.memberIdentifiers.size());
}
void GlslWriter::Visit(ShaderNodes::AssignOp& node)
void GlslWriter::Visit(ShaderAst::AssignExpression& node)
{
Visit(node.left);
node.left->Visit(*this);
switch (node.op)
{
case ShaderNodes::AssignType::Simple:
case ShaderAst::AssignType::Simple:
Append(" = ");
break;
}
Visit(node.right);
node.left->Visit(*this);
}
void GlslWriter::Visit(ShaderNodes::Branch& node)
void GlslWriter::Visit(ShaderAst::BranchStatement& node)
{
bool first = true;
for (const auto& statement : node.condStatements)
@@ -418,11 +348,11 @@ namespace Nz
Append("else ");
Append("if (");
Visit(statement.condition);
statement.condition->Visit(*this);
AppendLine(")");
EnterScope();
Visit(statement.statement);
statement.statement->Visit(*this);
LeaveScope();
first = false;
@@ -433,41 +363,36 @@ namespace Nz
AppendLine("else");
EnterScope();
Visit(node.elseStatement);
node.elseStatement->Visit(*this);
LeaveScope();
}
}
void GlslWriter::Visit(ShaderNodes::BinaryOp& node)
void GlslWriter::Visit(ShaderAst::BinaryExpression& node)
{
Visit(node.left, true);
switch (node.op)
{
case ShaderNodes::BinaryType::Add: Append(" + "); break;
case ShaderNodes::BinaryType::Subtract: Append(" - "); break;
case ShaderNodes::BinaryType::Multiply: Append(" * "); break;
case ShaderNodes::BinaryType::Divide: Append(" / "); break;
case ShaderAst::BinaryType::Add: Append(" + "); break;
case ShaderAst::BinaryType::Subtract: Append(" - "); break;
case ShaderAst::BinaryType::Multiply: Append(" * "); break;
case ShaderAst::BinaryType::Divide: Append(" / "); break;
case ShaderNodes::BinaryType::CompEq: Append(" == "); break;
case ShaderNodes::BinaryType::CompGe: Append(" >= "); break;
case ShaderNodes::BinaryType::CompGt: Append(" > "); break;
case ShaderNodes::BinaryType::CompLe: Append(" <= "); break;
case ShaderNodes::BinaryType::CompLt: Append(" < "); break;
case ShaderNodes::BinaryType::CompNe: Append(" != "); break;
case ShaderAst::BinaryType::CompEq: Append(" == "); break;
case ShaderAst::BinaryType::CompGe: Append(" >= "); break;
case ShaderAst::BinaryType::CompGt: Append(" > "); break;
case ShaderAst::BinaryType::CompLe: Append(" <= "); break;
case ShaderAst::BinaryType::CompLt: Append(" < "); break;
case ShaderAst::BinaryType::CompNe: Append(" != "); break;
}
Visit(node.right, true);
}
void GlslWriter::Visit(ShaderNodes::BuiltinVariable& var)
void GlslWriter::Visit(ShaderAst::CastExpression& node)
{
Append(var.entry);
}
void GlslWriter::Visit(ShaderNodes::Cast& node)
{
Append(node.exprType);
Append(node.targetType);
Append("(");
bool first = true;
@@ -479,34 +404,34 @@ namespace Nz
if (!first)
m_currentState->stream << ", ";
Visit(exprPtr);
exprPtr->Visit(*this);
first = false;
}
Append(")");
}
void GlslWriter::Visit(ShaderNodes::ConditionalExpression& node)
void GlslWriter::Visit(ShaderAst::ConditionalExpression& node)
{
std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
/*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex))
Visit(node.truePath);
else
Visit(node.falsePath);
Visit(node.falsePath);*/
}
void GlslWriter::Visit(ShaderNodes::ConditionalStatement& node)
void GlslWriter::Visit(ShaderAst::ConditionalStatement& node)
{
std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
/*std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_context.states->enabledConditions, conditionIndex))
Visit(node.statement);
Visit(node.statement);*/
}
void GlslWriter::Visit(ShaderNodes::Constant& node)
void GlslWriter::Visit(ShaderAst::ConstantExpression& node)
{
std::visit([&](auto&& arg)
{
@@ -530,54 +455,74 @@ namespace Nz
}, node.value);
}
void GlslWriter::Visit(ShaderNodes::DeclareVariable& node)
void GlslWriter::Visit(ShaderAst::DeclareFunctionStatement& node)
{
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
Append(localVar.type);
Append(node.returnType);
Append(" ");
Append(localVar.name);
if (node.expression)
Append(node.name);
Append("(");
for (std::size_t i = 0; i < node.parameters.size(); ++i)
{
if (i != 0)
Append(", ");
Append(node.parameters[i].type);
Append(" ");
Append(node.parameters[i].name);
}
Append(")\n");
EnterScope();
{
AstAdapter adapter;
adapter.flipYPosition = m_environment.flipYPosition;
for (auto& statement : node.statements)
adapter.Clone(statement)->Visit(*this);
}
LeaveScope();
}
void GlslWriter::Visit(ShaderAst::DeclareVariableStatement& node)
{
Append(node.varType);
Append(" ");
Append(node.varName);
if (node.initialExpression)
{
Append(" = ");
Visit(node.expression);
node.initialExpression->Visit(*this);
}
AppendLine(";");
}
void GlslWriter::Visit(ShaderNodes::Discard& /*node*/)
void GlslWriter::Visit(ShaderAst::DiscardStatement& /*node*/)
{
Append("discard;");
}
void GlslWriter::Visit(ShaderNodes::ExpressionStatement& node)
void GlslWriter::Visit(ShaderAst::ExpressionStatement& node)
{
Visit(node.expression);
node.expression->Visit(*this);
Append(";");
}
void GlslWriter::Visit(ShaderNodes::Identifier& node)
void GlslWriter::Visit(ShaderAst::IdentifierExpression& node)
{
Visit(node.var);
Append(node.identifier);
}
void GlslWriter::Visit(ShaderNodes::InputVariable& var)
{
Append(var.name);
}
void GlslWriter::Visit(ShaderNodes::IntrinsicCall& node)
void GlslWriter::Visit(ShaderAst::IntrinsicExpression& node)
{
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::CrossProduct:
case ShaderAst::IntrinsicType::CrossProduct:
Append("cross");
break;
case ShaderNodes::IntrinsicType::DotProduct:
case ShaderAst::IntrinsicType::DotProduct:
Append("dot");
break;
}
@@ -588,67 +533,43 @@ namespace Nz
if (i != 0)
Append(", ");
Visit(node.parameters[i]);
node.parameters[i]->Visit(*this);
}
Append(")");
}
void GlslWriter::Visit(ShaderNodes::LocalVariable& var)
void GlslWriter::Visit(ShaderAst::MultiStatement& node)
{
Append(var.name);
bool first = true;
for (const ShaderAst::StatementPtr& statement : node.statements)
{
if (!first && statement->GetType() != ShaderAst::NodeType::NoOpStatement)
AppendLine();
statement->Visit(*this);
first = false;
}
}
void GlslWriter::Visit(ShaderNodes::NoOp& /*node*/)
void GlslWriter::Visit(ShaderAst::NoOpStatement& /*node*/)
{
/* nothing to do */
}
void GlslWriter::Visit(ShaderNodes::ParameterVariable& var)
{
Append(var.name);
}
void GlslWriter::Visit(ShaderNodes::ReturnStatement& node)
void GlslWriter::Visit(ShaderAst::ReturnStatement& node)
{
if (node.returnExpr)
{
Append("return ");
Visit(node.returnExpr);
node.returnExpr->Visit(*this);
Append(";");
}
else
Append("return;");
}
void GlslWriter::Visit(ShaderNodes::OutputVariable& var)
{
Append(var.name);
}
void GlslWriter::Visit(ShaderNodes::Sample2D& node)
{
Append("texture(");
Visit(node.sampler);
Append(", ");
Visit(node.coordinates);
Append(")");
}
void GlslWriter::Visit(ShaderNodes::StatementBlock& node)
{
bool first = true;
for (const ShaderNodes::StatementPtr& statement : node.statements)
{
if (!first && statement->GetType() != ShaderNodes::NodeType::NoOp)
AppendLine();
Visit(statement);
first = false;
}
}
void GlslWriter::Visit(ShaderNodes::SwizzleOp& node)
void GlslWriter::Visit(ShaderAst::SwizzleExpression& node)
{
Visit(node.expression, true);
Append(".");
@@ -657,44 +578,39 @@ namespace Nz
{
switch (node.components[i])
{
case ShaderNodes::SwizzleComponent::First:
case ShaderAst::SwizzleComponent::First:
Append("x");
break;
case ShaderNodes::SwizzleComponent::Second:
case ShaderAst::SwizzleComponent::Second:
Append("y");
break;
case ShaderNodes::SwizzleComponent::Third:
case ShaderAst::SwizzleComponent::Third:
Append("z");
break;
case ShaderNodes::SwizzleComponent::Fourth:
case ShaderAst::SwizzleComponent::Fourth:
Append("w");
break;
}
}
}
void GlslWriter::Visit(ShaderNodes::UniformVariable& var)
bool GlslWriter::HasExplicitBinding(ShaderAst::StatementPtr& shader)
{
Append(var.name);
}
bool GlslWriter::HasExplicitBinding(const ShaderAst& shader)
{
for (const auto& uniform : shader.GetUniforms())
/*for (const auto& uniform : shader.GetUniforms())
{
if (uniform.bindingIndex.has_value())
return true;
}
}*/
return false;
}
bool GlslWriter::HasExplicitLocation(const ShaderAst& shader)
bool GlslWriter::HasExplicitLocation(ShaderAst::StatementPtr& shader)
{
for (const auto& input : shader.GetInputs())
/*for (const auto& input : shader.GetInputs())
{
if (input.locationIndex.has_value())
return true;
@@ -704,7 +620,7 @@ namespace Nz
{
if (output.locationIndex.has_value())
return true;
}
}*/
return false;
}

View File

@@ -1,56 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
void ShaderAst::AddCondition(std::string name)
{
auto& conditionEntry = m_conditions.emplace_back();
conditionEntry.name = std::move(name);
}
void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters, ShaderExpressionType returnType)
{
auto& functionEntry = m_functions.emplace_back();
functionEntry.name = std::move(name);
functionEntry.parameters = std::move(parameters);
functionEntry.returnType = returnType;
functionEntry.statement = std::move(statement);
}
void ShaderAst::AddInput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex)
{
auto& inputEntry = m_inputs.emplace_back();
inputEntry.name = std::move(name);
inputEntry.locationIndex = std::move(locationIndex);
inputEntry.type = std::move(type);
}
void ShaderAst::AddOutput(std::string name, ShaderExpressionType type, std::optional<std::size_t> locationIndex)
{
auto& outputEntry = m_outputs.emplace_back();
outputEntry.name = std::move(name);
outputEntry.locationIndex = std::move(locationIndex);
outputEntry.type = std::move(type);
}
void ShaderAst::AddStruct(std::string name, std::vector<StructMember> members)
{
auto& structEntry = m_structs.emplace_back();
structEntry.name = std::move(name);
structEntry.members = std::move(members);
}
void ShaderAst::AddUniform(std::string name, ShaderExpressionType type, std::optional<std::size_t> bindingIndex, std::optional<ShaderNodes::MemoryLayout> memoryLayout)
{
auto& uniformEntry = m_uniforms.emplace_back();
uniformEntry.bindingIndex = std::move(bindingIndex);
uniformEntry.memoryLayout = std::move(memoryLayout);
uniformEntry.name = std::move(name);
uniformEntry.type = std::move(type);
}
}

View File

@@ -6,240 +6,257 @@
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderNodes::StatementPtr ShaderAstCloner::Clone(const ShaderNodes::StatementPtr& statement)
ExpressionPtr AstCloner::Clone(ExpressionPtr& expr)
{
ShaderAstVisitor::Visit(statement);
expr->Visit(*this);
if (!m_expressionStack.empty() || !m_variableStack.empty() || m_statementStack.size() != 1)
throw std::runtime_error("An error occurred during clone");
assert(m_statementStack.empty() && m_expressionStack.size() == 1);
return PopExpression();
}
StatementPtr AstCloner::Clone(StatementPtr& statement)
{
statement->Visit(*this);
assert(m_expressionStack.empty() && m_statementStack.size() == 1);
return PopStatement();
}
ShaderNodes::ExpressionPtr ShaderAstCloner::CloneExpression(const ShaderNodes::ExpressionPtr& expr)
ExpressionPtr AstCloner::CloneExpression(ExpressionPtr& expr)
{
if (!expr)
return nullptr;
ShaderAstVisitor::Visit(expr);
expr->Visit(*this);
return PopExpression();
}
ShaderNodes::StatementPtr ShaderAstCloner::CloneStatement(const ShaderNodes::StatementPtr& statement)
StatementPtr AstCloner::CloneStatement(StatementPtr& statement)
{
if (!statement)
return nullptr;
ShaderAstVisitor::Visit(statement);
statement->Visit(*this);
return PopStatement();
}
ShaderNodes::VariablePtr ShaderAstCloner::CloneVariable(const ShaderNodes::VariablePtr& variable)
void AstCloner::Visit(AccessMemberExpression& node)
{
if (!variable)
return nullptr;
auto clone = std::make_unique<AccessMemberExpression>();
clone->memberIdentifiers = node.memberIdentifiers;
clone->structExpr = CloneExpression(node.structExpr);
ShaderVarVisitor::Visit(variable);
return PopVariable();
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::AccessMember& node)
void AstCloner::Visit(AssignExpression& node)
{
PushExpression(ShaderNodes::AccessMember::Build(CloneExpression(node.structExpr), node.memberIndices, node.exprType));
auto clone = std::make_unique<AssignExpression>();
clone->op = node.op;
clone->left = CloneExpression(node.left);
clone->right = CloneExpression(node.right);
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::AssignOp& node)
void AstCloner::Visit(BinaryExpression& node)
{
PushExpression(ShaderNodes::AssignOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right)));
auto clone = std::make_unique<BinaryExpression>();
clone->op = node.op;
clone->left = CloneExpression(node.left);
clone->right = CloneExpression(node.right);
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::BinaryOp& node)
void AstCloner::Visit(CastExpression& node)
{
PushExpression(ShaderNodes::BinaryOp::Build(node.op, CloneExpression(node.left), CloneExpression(node.right)));
}
auto clone = std::make_unique<CastExpression>();
clone->targetType = node.targetType;
void ShaderAstCloner::Visit(ShaderNodes::Branch& node)
{
std::vector<ShaderNodes::Branch::ConditionalStatement> condStatements;
condStatements.reserve(node.condStatements.size());
for (auto& cond : node.condStatements)
{
auto& condStatement = condStatements.emplace_back();
condStatement.condition = CloneExpression(cond.condition);
condStatement.statement = CloneStatement(cond.statement);
}
PushStatement(ShaderNodes::Branch::Build(std::move(condStatements), CloneStatement(node.elseStatement)));
}
void ShaderAstCloner::Visit(ShaderNodes::Cast& node)
{
std::size_t expressionCount = 0;
std::array<ShaderNodes::ExpressionPtr, 4> expressions;
for (auto& expr : node.expressions)
{
if (!expr)
break;
expressions[expressionCount] = CloneExpression(expr);
expressionCount++;
clone->expressions[expressionCount++] = CloneExpression(expr);
}
PushExpression(ShaderNodes::Cast::Build(node.exprType, expressions.data(), expressionCount));
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::ConditionalExpression& node)
void AstCloner::Visit(ConditionalExpression& node)
{
PushExpression(ShaderNodes::ConditionalExpression::Build(node.conditionName, CloneExpression(node.truePath), CloneExpression(node.falsePath)));
auto clone = std::make_unique<ConditionalExpression>();
clone->conditionName = node.conditionName;
clone->falsePath = CloneExpression(node.falsePath);
clone->truePath = CloneExpression(node.truePath);
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::ConditionalStatement& node)
void AstCloner::Visit(ConstantExpression& node)
{
PushStatement(ShaderNodes::ConditionalStatement::Build(node.conditionName, CloneStatement(node.statement)));
auto clone = std::make_unique<ConstantExpression>();
clone->value = node.value;
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::Constant& node)
void AstCloner::Visit(IdentifierExpression& node)
{
PushExpression(ShaderNodes::Constant::Build(node.value));
auto clone = std::make_unique<IdentifierExpression>();
clone->identifier = node.identifier;
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::DeclareVariable& node)
void AstCloner::Visit(IntrinsicExpression& node)
{
PushStatement(ShaderNodes::DeclareVariable::Build(CloneVariable(node.variable), CloneExpression(node.expression)));
}
void ShaderAstCloner::Visit(ShaderNodes::Discard& /*node*/)
{
PushStatement(ShaderNodes::Discard::Build());
}
void ShaderAstCloner::Visit(ShaderNodes::ExpressionStatement& node)
{
PushStatement(ShaderNodes::ExpressionStatement::Build(CloneExpression(node.expression)));
}
void ShaderAstCloner::Visit(ShaderNodes::Identifier& node)
{
PushExpression(ShaderNodes::Identifier::Build(CloneVariable(node.var)));
}
void ShaderAstCloner::Visit(ShaderNodes::IntrinsicCall& node)
{
std::vector<ShaderNodes::ExpressionPtr> parameters;
parameters.reserve(node.parameters.size());
auto clone = std::make_unique<IntrinsicExpression>();
clone->intrinsic = node.intrinsic;
clone->parameters.reserve(node.parameters.size());
for (auto& parameter : node.parameters)
parameters.push_back(CloneExpression(parameter));
clone->parameters.push_back(CloneExpression(parameter));
PushExpression(ShaderNodes::IntrinsicCall::Build(node.intrinsic, std::move(parameters)));
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::NoOp& /*node*/)
void AstCloner::Visit(SwizzleExpression& node)
{
PushStatement(ShaderNodes::NoOp::Build());
auto clone = std::make_unique<SwizzleExpression>();
clone->componentCount = node.componentCount;
clone->components = node.components;
clone->expression = CloneExpression(node.expression);
PushExpression(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::ReturnStatement& node)
void AstCloner::Visit(BranchStatement& node)
{
PushStatement(ShaderNodes::ReturnStatement::Build(CloneExpression(node.returnExpr)));
auto clone = std::make_unique<BranchStatement>();
clone->condStatements.reserve(node.condStatements.size());
for (auto& cond : node.condStatements)
{
auto& condStatement = clone->condStatements.emplace_back();
condStatement.condition = CloneExpression(cond.condition);
condStatement.statement = CloneStatement(cond.statement);
}
clone->elseStatement = CloneStatement(node.elseStatement);
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::Sample2D& node)
void AstCloner::Visit(ConditionalStatement& node)
{
PushExpression(ShaderNodes::Sample2D::Build(CloneExpression(node.sampler), CloneExpression(node.coordinates)));
auto clone = std::make_unique<ConditionalStatement>();
clone->conditionName = node.conditionName;
clone->statement = CloneStatement(node.statement);
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::StatementBlock& node)
void AstCloner::Visit(DeclareFunctionStatement& node)
{
std::vector<ShaderNodes::StatementPtr> statements;
statements.reserve(node.statements.size());
auto clone = std::make_unique<DeclareFunctionStatement>();
clone->name = node.name;
clone->parameters = node.parameters;
clone->returnType = node.returnType;
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
statements.push_back(CloneStatement(statement));
clone->statements.push_back(CloneStatement(statement));
PushStatement(ShaderNodes::StatementBlock::Build(std::move(statements)));
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::SwizzleOp& node)
void AstCloner::Visit(DeclareStructStatement& node)
{
PushExpression(ShaderNodes::SwizzleOp::Build(CloneExpression(node.expression), node.components.data(), node.componentCount));
auto clone = std::make_unique<DeclareStructStatement>();
clone->description = node.description;
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::BuiltinVariable& var)
void AstCloner::Visit(DeclareVariableStatement& node)
{
PushVariable(ShaderNodes::BuiltinVariable::Build(var.entry, var.type));
auto clone = std::make_unique<DeclareVariableStatement>();
clone->varName = node.varName;
clone->varType = node.varType;
clone->initialExpression = CloneExpression(node.initialExpression);
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::InputVariable& var)
void AstCloner::Visit(DiscardStatement& /*node*/)
{
PushVariable(ShaderNodes::InputVariable::Build(var.name, var.type));
PushStatement(std::make_unique<DiscardStatement>());
}
void ShaderAstCloner::Visit(ShaderNodes::LocalVariable& var)
void AstCloner::Visit(ExpressionStatement& node)
{
PushVariable(ShaderNodes::LocalVariable::Build(var.name, var.type));
auto clone = std::make_unique<ExpressionStatement>();
clone->expression = CloneExpression(node.expression);
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::OutputVariable& var)
void AstCloner::Visit(MultiStatement& node)
{
PushVariable(ShaderNodes::OutputVariable::Build(var.name, var.type));
auto clone = std::make_unique<MultiStatement>();
clone->statements.reserve(node.statements.size());
for (auto& statement : node.statements)
clone->statements.push_back(CloneStatement(statement));
PushStatement(std::move(clone));
}
void ShaderAstCloner::Visit(ShaderNodes::ParameterVariable& var)
void AstCloner::Visit(NoOpStatement& /*node*/)
{
PushVariable(ShaderNodes::ParameterVariable::Build(var.name, var.type));
PushStatement(std::make_unique<NoOpStatement>());
}
void ShaderAstCloner::Visit(ShaderNodes::UniformVariable& var)
void AstCloner::Visit(ReturnStatement& node)
{
PushVariable(ShaderNodes::UniformVariable::Build(var.name, var.type));
auto clone = std::make_unique<ReturnStatement>();
clone->returnExpr = CloneExpression(node.returnExpr);
PushStatement(std::move(clone));
}
void ShaderAstCloner::PushExpression(ShaderNodes::ExpressionPtr expression)
void AstCloner::PushExpression(ExpressionPtr expression)
{
m_expressionStack.emplace_back(std::move(expression));
}
void ShaderAstCloner::PushStatement(ShaderNodes::StatementPtr statement)
void AstCloner::PushStatement(StatementPtr statement)
{
m_statementStack.emplace_back(std::move(statement));
}
void ShaderAstCloner::PushVariable(ShaderNodes::VariablePtr variable)
{
m_variableStack.emplace_back(std::move(variable));
}
ShaderNodes::ExpressionPtr ShaderAstCloner::PopExpression()
ExpressionPtr AstCloner::PopExpression()
{
assert(!m_expressionStack.empty());
ShaderNodes::ExpressionPtr expr = std::move(m_expressionStack.back());
ExpressionPtr expr = std::move(m_expressionStack.back());
m_expressionStack.pop_back();
return expr;
}
ShaderNodes::StatementPtr ShaderAstCloner::PopStatement()
StatementPtr AstCloner::PopStatement()
{
assert(!m_statementStack.empty());
ShaderNodes::StatementPtr expr = std::move(m_statementStack.back());
StatementPtr expr = std::move(m_statementStack.back());
m_statementStack.pop_back();
return expr;
}
ShaderNodes::VariablePtr ShaderAstCloner::PopVariable()
{
assert(!m_variableStack.empty());
ShaderNodes::VariablePtr var = std::move(m_variableStack.back());
m_variableStack.pop_back();
return var;
}
}

View File

@@ -0,0 +1,198 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/ShaderAstCache.hpp>
#include <optional>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
ShaderExpressionType ExpressionTypeVisitor::GetExpressionType(Expression& expression, AstCache* cache = nullptr)
{
m_cache = cache;
ShaderExpressionType type = GetExpressionTypeInternal(expression);
m_cache = nullptr;
return type;
}
ShaderExpressionType ExpressionTypeVisitor::GetExpressionTypeInternal(Expression& expression)
{
m_lastExpressionType.reset();
Visit(expression);
assert(m_lastExpressionType.has_value());
return std::move(*m_lastExpressionType);
}
void ExpressionTypeVisitor::Visit(Expression& expression)
{
if (m_cache)
{
auto it = m_cache->nodeExpressionType.find(&expression);
if (it != m_cache->nodeExpressionType.end())
{
m_lastExpressionType = it->second;
return;
}
}
expression.Visit(*this);
if (m_cache)
{
assert(m_lastExpressionType.has_value());
m_cache->nodeExpressionType.emplace(&expression, *m_lastExpressionType);
}
}
void ExpressionTypeVisitor::Visit(AccessMemberExpression& node)
{
throw std::runtime_error("unhandled accessmember expression");
}
void ExpressionTypeVisitor::Visit(AssignExpression& node)
{
Visit(*node.left);
}
void ExpressionTypeVisitor::Visit(BinaryExpression& node)
{
switch (node.op)
{
case BinaryType::Add:
case BinaryType::Subtract:
return Visit(*node.left);
case BinaryType::Divide:
case BinaryType::Multiply:
{
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.left);
assert(IsBasicType(leftExprType));
ShaderExpressionType rightExprType = GetExpressionTypeInternal(*node.right);
assert(IsBasicType(rightExprType));
switch (std::get<BasicType>(leftExprType))
{
case BasicType::Boolean:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
m_lastExpressionType = std::move(leftExprType);
break;
case BasicType::Float1:
case BasicType::Int1:
case BasicType::Mat4x4:
case BasicType::UInt1:
m_lastExpressionType = std::move(rightExprType);
break;
case BasicType::Sampler2D:
case BasicType::Void:
break;
}
break;
}
case BinaryType::CompEq:
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompNe:
m_lastExpressionType = BasicType::Boolean;
break;
}
}
void ExpressionTypeVisitor::Visit(CastExpression& node)
{
m_lastExpressionType = node.targetType;
}
void ExpressionTypeVisitor::Visit(ConditionalExpression& node)
{
ShaderExpressionType leftExprType = GetExpressionTypeInternal(*node.truePath);
assert(leftExprType == GetExpressionTypeInternal(*node.falsePath));
m_lastExpressionType = std::move(leftExprType);
}
void ExpressionTypeVisitor::Visit(ConstantExpression& node)
{
m_lastExpressionType = std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return BasicType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return BasicType::Float1;
else if constexpr (std::is_same_v<T, Int32>)
return BasicType::Int1;
else if constexpr (std::is_same_v<T, UInt32>)
return BasicType::Int1;
else if constexpr (std::is_same_v<T, Vector2f>)
return BasicType::Float2;
else if constexpr (std::is_same_v<T, Vector3f>)
return BasicType::Float3;
else if constexpr (std::is_same_v<T, Vector4f>)
return BasicType::Float4;
else if constexpr (std::is_same_v<T, Vector2i32>)
return BasicType::Int2;
else if constexpr (std::is_same_v<T, Vector3i32>)
return BasicType::Int3;
else if constexpr (std::is_same_v<T, Vector4i32>)
return BasicType::Int4;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, node.value);
}
void ExpressionTypeVisitor::Visit(IdentifierExpression& node)
{
auto scopeIt = m_cache->scopeIdByNode.find(&node);
if (scopeIt == m_cache->scopeIdByNode.end())
throw std::runtime_error("internal error");
const AstCache::Identifier* identifier = m_cache->FindIdentifier(scopeIt->second, node.identifier);
if (!identifier || !std::holds_alternative<AstCache::Variable>(identifier->value))
throw std::runtime_error("internal error");
m_lastExpressionType = std::get<AstCache::Variable>(identifier->value).type;
}
void ExpressionTypeVisitor::Visit(IntrinsicExpression& node)
{
switch (node.intrinsic)
{
case IntrinsicType::CrossProduct:
Visit(*node.parameters.front());
break;
case IntrinsicType::DotProduct:
m_lastExpressionType = BasicType::Float1;
break;
}
}
void ExpressionTypeVisitor::Visit(SwizzleExpression& node)
{
const ShaderExpressionType& exprType = GetExpressionTypeInternal(*node.expression);
assert(IsBasicType(exprType));
m_lastExpressionType = static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + node.componentCount - 1);
}
}

View File

@@ -2,15 +2,10 @@
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderAstVisitor::~ShaderAstVisitor() = default;
void ShaderAstVisitor::Visit(const ShaderNodes::NodePtr& node)
{
node->Visit(*this);
}
AstExpressionVisitor::~AstExpressionVisitor() = default;
}

View File

@@ -0,0 +1,15 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstExpressionVisitorExcept.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
#define NAZARA_SHADERAST_EXPRESSION(Node) void ExpressionVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
{ \
throw std::runtime_error("unexpected " #Node " node"); \
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
}

View File

@@ -3,16 +3,22 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstOptimizer.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <cassert>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
namespace
{
template<typename T, typename U>
std::unique_ptr<T> static_unique_pointer_cast(std::unique_ptr<U>&& ptr)
{
return std::unique_ptr<T>(static_cast<T*>(ptr.release()));
}
template <typename T>
struct is_complete_helper
{
@@ -29,14 +35,14 @@ namespace Nz
inline constexpr bool is_complete_v = is_complete<T>::value;
template<ShaderNodes::BinaryType Type, typename T1, typename T2>
template<BinaryType Type, typename T1, typename T2>
struct PropagateConstantType;
// CompEq
template<typename T1, typename T2>
struct CompEqBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs == rhs);
}
@@ -46,7 +52,7 @@ namespace Nz
struct CompEq;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompEq, T1, T2>
struct PropagateConstantType<BinaryType::CompEq, T1, T2>
{
using Op = typename CompEq<T1, T2>;
};
@@ -55,7 +61,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompGeBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs >= rhs);
}
@@ -65,7 +71,7 @@ namespace Nz
struct CompGe;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompGe, T1, T2>
struct PropagateConstantType<BinaryType::CompGe, T1, T2>
{
using Op = typename CompGe<T1, T2>;
};
@@ -74,7 +80,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompGtBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs > rhs);
}
@@ -84,7 +90,7 @@ namespace Nz
struct CompGt;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompGt, T1, T2>
struct PropagateConstantType<BinaryType::CompGt, T1, T2>
{
using Op = typename CompGt<T1, T2>;
};
@@ -93,7 +99,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompLeBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs <= rhs);
}
@@ -103,7 +109,7 @@ namespace Nz
struct CompLe;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompLe, T1, T2>
struct PropagateConstantType<BinaryType::CompLe, T1, T2>
{
using Op = typename CompLe<T1, T2>;
};
@@ -112,7 +118,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompLtBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs < rhs);
}
@@ -122,7 +128,7 @@ namespace Nz
struct CompLt;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompLt, T1, T2>
struct PropagateConstantType<BinaryType::CompLt, T1, T2>
{
using Op = typename CompLe<T1, T2>;
};
@@ -131,7 +137,7 @@ namespace Nz
template<typename T1, typename T2>
struct CompNeBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs != rhs);
}
@@ -141,7 +147,7 @@ namespace Nz
struct CompNe;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::CompNe, T1, T2>
struct PropagateConstantType<BinaryType::CompNe, T1, T2>
{
using Op = typename CompNe<T1, T2>;
};
@@ -150,7 +156,7 @@ namespace Nz
template<typename T1, typename T2>
struct AdditionBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs + rhs);
}
@@ -160,7 +166,7 @@ namespace Nz
struct Addition;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::Add, T1, T2>
struct PropagateConstantType<BinaryType::Add, T1, T2>
{
using Op = typename Addition<T1, T2>;
};
@@ -169,7 +175,7 @@ namespace Nz
template<typename T1, typename T2>
struct DivisionBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs / rhs);
}
@@ -179,7 +185,7 @@ namespace Nz
struct Division;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::Divide, T1, T2>
struct PropagateConstantType<BinaryType::Divide, T1, T2>
{
using Op = typename Division<T1, T2>;
};
@@ -188,7 +194,7 @@ namespace Nz
template<typename T1, typename T2>
struct MultiplicationBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs * rhs);
}
@@ -198,7 +204,7 @@ namespace Nz
struct Multiplication;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::Multiply, T1, T2>
struct PropagateConstantType<BinaryType::Multiply, T1, T2>
{
using Op = typename Multiplication<T1, T2>;
};
@@ -207,7 +213,7 @@ namespace Nz
template<typename T1, typename T2>
struct SubtractionBase
{
ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
ExpressionPtr operator()(const T1& lhs, const T2& rhs)
{
return ShaderBuilder::Constant(lhs - rhs);
}
@@ -217,7 +223,7 @@ namespace Nz
struct Subtraction;
template<typename T1, typename T2>
struct PropagateConstantType<ShaderNodes::BinaryType::Subtract, T1, T2>
struct PropagateConstantType<BinaryType::Subtract, T1, T2>
{
using Op = typename Subtraction<T1, T2>;
};
@@ -375,92 +381,89 @@ namespace Nz
#undef EnableOptimisation
}
ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement)
StatementPtr AstOptimizer::Optimise(StatementPtr& statement)
{
m_shaderAst = nullptr;
return CloneStatement(statement);
}
ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement, const ShaderAst& shader, UInt64 enabledConditions)
StatementPtr AstOptimizer::Optimise(StatementPtr& statement, UInt64 enabledConditions)
{
m_shaderAst = &shader;
m_enabledConditions = enabledConditions;
return CloneStatement(statement);
}
void ShaderAstOptimizer::Visit(ShaderNodes::BinaryOp& node)
void AstOptimizer::Visit(BinaryExpression& node)
{
auto lhs = CloneExpression(node.left);
auto rhs = CloneExpression(node.right);
if (lhs->GetType() == ShaderNodes::NodeType::Constant && rhs->GetType() == ShaderNodes::NodeType::Constant)
if (lhs->GetType() == NodeType::ConstantExpression && rhs->GetType() == NodeType::ConstantExpression)
{
auto lhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(lhs);
auto rhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(rhs);
auto lhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(lhs));
auto rhsConstant = static_unique_pointer_cast<ConstantExpression>(std::move(rhs));
switch (node.op)
{
case ShaderNodes::BinaryType::Add:
return PropagateConstant<ShaderNodes::BinaryType::Add>(lhsConstant, rhsConstant);
case BinaryType::Add:
return PropagateConstant<BinaryType::Add>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::Subtract:
return PropagateConstant<ShaderNodes::BinaryType::Subtract>(lhsConstant, rhsConstant);
case BinaryType::Subtract:
return PropagateConstant<BinaryType::Subtract>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::Multiply:
return PropagateConstant<ShaderNodes::BinaryType::Multiply>(lhsConstant, rhsConstant);
case BinaryType::Multiply:
return PropagateConstant<BinaryType::Multiply>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::Divide:
return PropagateConstant<ShaderNodes::BinaryType::Divide>(lhsConstant, rhsConstant);
case BinaryType::Divide:
return PropagateConstant<BinaryType::Divide>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompEq:
return PropagateConstant<ShaderNodes::BinaryType::CompEq>(lhsConstant, rhsConstant);
case BinaryType::CompEq:
return PropagateConstant<BinaryType::CompEq>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompGe:
return PropagateConstant<ShaderNodes::BinaryType::CompGe>(lhsConstant, rhsConstant);
case BinaryType::CompGe:
return PropagateConstant<BinaryType::CompGe>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompGt:
return PropagateConstant<ShaderNodes::BinaryType::CompGt>(lhsConstant, rhsConstant);
case BinaryType::CompGt:
return PropagateConstant<BinaryType::CompGt>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompLe:
return PropagateConstant<ShaderNodes::BinaryType::CompLe>(lhsConstant, rhsConstant);
case BinaryType::CompLe:
return PropagateConstant<BinaryType::CompLe>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompLt:
return PropagateConstant<ShaderNodes::BinaryType::CompLt>(lhsConstant, rhsConstant);
case BinaryType::CompLt:
return PropagateConstant<BinaryType::CompLt>(std::move(lhsConstant), std::move(rhsConstant));
case ShaderNodes::BinaryType::CompNe:
return PropagateConstant<ShaderNodes::BinaryType::CompNe>(lhsConstant, rhsConstant);
case BinaryType::CompNe:
return PropagateConstant<BinaryType::CompNe>(std::move(lhsConstant), std::move(rhsConstant));
}
}
ShaderAstCloner::Visit(node);
AstCloner::Visit(node);
}
void ShaderAstOptimizer::Visit(ShaderNodes::Branch& node)
void AstOptimizer::Visit(BranchStatement& node)
{
std::vector<ShaderNodes::Branch::ConditionalStatement> statements;
ShaderNodes::StatementPtr elseStatement;
std::vector<BranchStatement::ConditionalStatement> statements;
StatementPtr elseStatement;
for (auto& condStatement : node.condStatements)
{
auto cond = CloneExpression(condStatement.condition);
if (cond->GetType() == ShaderNodes::NodeType::Constant)
if (cond->GetType() == NodeType::ConstantExpression)
{
auto constant = std::static_pointer_cast<ShaderNodes::Constant>(cond);
auto& constant = static_cast<ConstantExpression&>(*cond);
assert(IsBasicType(cond->GetExpressionType()));
assert(std::get<ShaderNodes::BasicType>(cond->GetExpressionType()) == ShaderNodes::BasicType::Boolean);
assert(IsBasicType(GetExpressionType(constant)));
assert(std::get<BasicType>(GetExpressionType(constant)) == BasicType::Boolean);
bool cValue = std::get<bool>(constant->value);
bool cValue = std::get<bool>(constant.value);
if (!cValue)
continue;
if (statements.empty())
{
// First condition is true, dismiss the branch
Visit(condStatement.statement);
condStatement.statement->Visit(*this);
return;
}
else
@@ -482,47 +485,54 @@ namespace Nz
{
// All conditions have been removed, replace by else statement or no-op
if (node.elseStatement)
return Visit(node.elseStatement);
{
node.elseStatement->Visit(*this);
return;
}
else
return PushStatement(ShaderNodes::NoOp::Build());
return PushStatement(ShaderBuilder::NoOp());
}
if (!elseStatement)
elseStatement = CloneStatement(node.elseStatement);
PushStatement(ShaderNodes::Branch::Build(std::move(statements), std::move(elseStatement)));
PushStatement(ShaderBuilder::Branch(std::move(statements), std::move(elseStatement)));
}
void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalExpression& node)
void AstOptimizer::Visit(ConditionalExpression& node)
{
if (!m_shaderAst)
return AstCloner::Visit(node);
/*if (!m_shaderAst)
return ShaderAstCloner::Visit(node);
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
assert(conditionIndex != InvalidCondition);
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
Visit(node.truePath);
else
Visit(node.falsePath);
Visit(node.falsePath);*/
}
void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalStatement& node)
void AstOptimizer::Visit(ConditionalStatement& node)
{
if (!m_shaderAst)
return AstCloner::Visit(node);
/*if (!m_shaderAst)
return ShaderAstCloner::Visit(node);
std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
assert(conditionIndex != InvalidCondition);
if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
Visit(node.statement);
Visit(node.statement);*/
}
template<ShaderNodes::BinaryType Type>
void ShaderAstOptimizer::PropagateConstant(const std::shared_ptr<ShaderNodes::Constant>& lhs, const std::shared_ptr<ShaderNodes::Constant>& rhs)
template<BinaryType Type>
void AstOptimizer::PropagateConstant(std::unique_ptr<ConstantExpression>&& lhs, std::unique_ptr<ConstantExpression>&& rhs)
{
ShaderNodes::ExpressionPtr optimized;
ExpressionPtr optimized;
std::visit([&](auto&& arg1)
{
using T1 = std::decay_t<decltype(arg1)>;
@@ -543,8 +553,8 @@ namespace Nz
}, lhs->value);
if (optimized)
PushExpression(optimized);
PushExpression(std::move(optimized));
else
PushExpression(ShaderNodes::BinaryOp::Build(Type, lhs, rhs));
PushExpression(ShaderBuilder::Binary(Type, std::move(lhs), std::move(rhs)));
}
}

View File

@@ -5,116 +5,121 @@
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AccessMember& node)
void AstRecursiveVisitor::Visit(AccessMemberExpression& node)
{
Visit(node.structExpr);
node.structExpr->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::AssignOp& node)
void AstRecursiveVisitor::Visit(AssignExpression& node)
{
Visit(node.left);
Visit(node.right);
node.left->Visit(*this);
node.right->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::BinaryOp& node)
void AstRecursiveVisitor::Visit(BinaryExpression& node)
{
Visit(node.left);
Visit(node.right);
node.left->Visit(*this);
node.right->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Branch& node)
{
for (auto& cond : node.condStatements)
{
Visit(cond.condition);
Visit(cond.statement);
}
if (node.elseStatement)
Visit(node.elseStatement);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Cast& node)
void AstRecursiveVisitor::Visit(CastExpression& node)
{
for (auto& expr : node.expressions)
{
if (!expr)
break;
Visit(expr);
expr->Visit(*this);
}
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalExpression& node)
void AstRecursiveVisitor::Visit(ConditionalExpression& node)
{
Visit(node.truePath);
Visit(node.falsePath);
node.truePath->Visit(*this);
node.falsePath->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ConditionalStatement& node)
{
Visit(node.statement);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Constant& /*node*/)
void AstRecursiveVisitor::Visit(ConstantExpression& /*node*/)
{
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::DeclareVariable& node)
{
if (node.expression)
Visit(node.expression);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Discard& /*node*/)
void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/)
{
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ExpressionStatement& node)
{
Visit(node.expression);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Identifier& /*node*/)
{
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::IntrinsicCall& node)
void AstRecursiveVisitor::Visit(IntrinsicExpression& node)
{
for (auto& param : node.parameters)
Visit(param);
param->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::NoOp& /*node*/)
void AstRecursiveVisitor::Visit(SwizzleExpression& node)
{
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(BranchStatement& node)
{
for (auto& cond : node.condStatements)
{
cond.condition->Visit(*this);
cond.statement->Visit(*this);
}
if (node.elseStatement)
node.elseStatement->Visit(*this);
}
void AstRecursiveVisitor::Visit(ConditionalStatement& node)
{
node.statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareFunctionStatement& node)
{
for (auto& statement : node.statements)
statement->Visit(*this);
}
void AstRecursiveVisitor::Visit(DeclareStructStatement& /*node*/)
{
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ReturnStatement& node)
void AstRecursiveVisitor::Visit(DeclareVariableStatement& node)
{
if (node.returnExpr)
Visit(node.returnExpr);
if (node.initialExpression)
node.initialExpression->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Sample2D& node)
void AstRecursiveVisitor::Visit(DiscardStatement& /*node*/)
{
Visit(node.sampler);
Visit(node.coordinates);
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::StatementBlock& node)
void AstRecursiveVisitor::Visit(ExpressionStatement& node)
{
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(MultiStatement& node)
{
for (auto& statement : node.statements)
Visit(statement);
statement->Visit(*this);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::SwizzleOp& node)
void AstRecursiveVisitor::Visit(NoOpStatement& /*node*/)
{
Visit(node.expression);
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(ReturnStatement& node)
{
if (node.returnExpr)
node.returnExpr->Visit(*this);
}
}

View File

@@ -3,221 +3,74 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstSerializer.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
namespace
{
constexpr UInt32 s_magicNumber = 0x4E534852;
constexpr UInt32 s_currentVersion = 1;
class ShaderSerializerVisitor : public ShaderAstVisitor, public ShaderVarVisitor
class ShaderSerializerVisitor : public AstExpressionVisitor, public AstStatementVisitor
{
public:
ShaderSerializerVisitor(ShaderAstSerializerBase& serializer) :
ShaderSerializerVisitor(AstSerializerBase& serializer) :
m_serializer(serializer)
{
}
void Visit(ShaderNodes::AccessMember& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::AssignOp& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::BinaryOp& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Branch& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Cast& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::ConditionalExpression& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::ConditionalStatement& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Constant& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::DeclareVariable& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Discard& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::ExpressionStatement& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Identifier& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::IntrinsicCall& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::NoOp& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::ReturnStatement& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Sample2D& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::StatementBlock& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::SwizzleOp& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::BuiltinVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::InputVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::LocalVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::OutputVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::ParameterVariable& var) override
{
Serialize(var);
}
void Visit(ShaderNodes::UniformVariable& var) override
{
Serialize(var);
#define NAZARA_SHADERAST_NODE(Node) void Visit(Node& node) override \
{ \
m_serializer.Serialize(node); \
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
private:
template<typename T>
void Serialize(const T& node)
{
// I know const_cast is evil but I don't have a better solution here (it's not used to write)
m_serializer.Serialize(const_cast<T&>(node));
}
ShaderAstSerializerBase& m_serializer;
AstSerializerBase& m_serializer;
};
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::AccessMember& node)
void AstSerializerBase::Serialize(AccessMemberExpression& node)
{
Node(node.structExpr);
Type(node.exprType);
Container(node.memberIndices);
for (std::size_t& index : node.memberIndices)
SizeT(index);
Container(node.memberIdentifiers);
for (std::string& identifier : node.memberIdentifiers)
Value(identifier);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::AssignOp& node)
void AstSerializerBase::Serialize(AssignExpression& node)
{
Enum(node.op);
Node(node.left);
Node(node.right);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::BinaryOp& node)
void AstSerializerBase::Serialize(BinaryExpression& node)
{
Enum(node.op);
Node(node.left);
Node(node.right);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Branch& node)
void AstSerializerBase::Serialize(CastExpression& node)
{
Container(node.condStatements);
for (auto& condStatement : node.condStatements)
{
Node(condStatement.condition);
Node(condStatement.statement);
}
Node(node.elseStatement);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node)
{
Enum(node.entry);
Type(node.type);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Cast& node)
{
Enum(node.exprType);
Enum(node.targetType);
for (auto& expr : node.expressions)
Node(expr);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalExpression& node)
void AstSerializerBase::Serialize(ConditionalExpression& node)
{
Value(node.conditionName);
Node(node.truePath);
Node(node.falsePath);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ConditionalStatement& node)
{
Value(node.conditionName);
Node(node.statement);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Constant& node)
void AstSerializerBase::Serialize(ConstantExpression& node)
{
UInt32 typeIndex;
if (IsWriting())
@@ -251,28 +104,19 @@ namespace Nz
}
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::DeclareVariable& node)
void AstSerializerBase::Serialize(DeclareVariableStatement& node)
{
Variable(node.variable);
Node(node.expression);
Value(node.varName);
Type(node.varType);
Node(node.initialExpression);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Discard& /*node*/)
void AstSerializerBase::Serialize(IdentifierExpression& node)
{
/* Nothing to do */
Value(node.identifier);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ExpressionStatement& node)
{
Node(node.expression);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Identifier& node)
{
Variable(node.var);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::IntrinsicCall& node)
void AstSerializerBase::Serialize(IntrinsicExpression& node)
{
Enum(node.intrinsic);
Container(node.parameters);
@@ -280,36 +124,7 @@ namespace Nz
Node(param);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::NamedVariable& node)
{
Value(node.name);
Type(node.type);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::NoOp& /*node*/)
{
/* Nothing to do */
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ReturnStatement& node)
{
Node(node.returnExpr);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Sample2D& node)
{
Node(node.sampler);
Node(node.coordinates);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::StatementBlock& node)
{
Container(node.statements);
for (auto& statement : node.statements)
Node(statement);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::SwizzleOp& node)
void AstSerializerBase::Serialize(SwizzleExpression& node)
{
SizeT(node.componentCount);
Node(node.expression);
@@ -319,100 +134,85 @@ namespace Nz
}
void ShaderAstSerializer::Serialize(const ShaderAst& shader)
void AstSerializerBase::Serialize(BranchStatement& node)
{
Container(node.condStatements);
for (auto& condStatement : node.condStatements)
{
Node(condStatement.condition);
Node(condStatement.statement);
}
Node(node.elseStatement);
}
void AstSerializerBase::Serialize(ConditionalStatement& node)
{
Value(node.conditionName);
Node(node.statement);
}
void AstSerializerBase::Serialize(DeclareFunctionStatement& node)
{
Value(node.name);
Type(node.returnType);
Container(node.parameters);
for (auto& parameter : node.parameters)
{
Value(parameter.name);
Type(parameter.type);
}
Container(node.statements);
for (auto& statement : node.statements)
Node(statement);
}
void AstSerializerBase::Serialize(DeclareStructStatement& node)
{
Value(node.description.name);
Container(node.description.members);
for (auto& member : node.description.members)
{
Value(member.name);
Type(member.type);
}
}
void AstSerializerBase::Serialize(DiscardStatement& /*node*/)
{
/* Nothing to do */
}
void AstSerializerBase::Serialize(ExpressionStatement& node)
{
Node(node.expression);
}
void AstSerializerBase::Serialize(MultiStatement& node)
{
Container(node.statements);
for (auto& statement : node.statements)
Node(statement);
}
void AstSerializerBase::Serialize(NoOpStatement& /*node*/)
{
/* Nothing to do */
}
void AstSerializerBase::Serialize(ReturnStatement& node)
{
Node(node.returnExpr);
}
void ShaderAstSerializer::Serialize(StatementPtr& shader)
{
m_stream << s_magicNumber << s_currentVersion;
m_stream << UInt32(shader.GetStage());
auto SerializeType = [&](const ShaderExpressionType& type)
{
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
{
m_stream << UInt8(0);
m_stream << UInt32(arg);
}
else if constexpr (std::is_same_v<T, std::string>)
{
m_stream << UInt8(1);
m_stream << arg;
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, type);
};
auto SerializeInputOutput = [&](auto& inout)
{
m_stream << UInt32(inout.size());
for (const auto& data : inout)
{
m_stream << data.name;
SerializeType(data.type);
m_stream << data.locationIndex.has_value();
if (data.locationIndex)
m_stream << UInt32(data.locationIndex.value());
}
};
// Conditions
m_stream << UInt32(shader.GetConditionCount());
for (const auto& cond : shader.GetConditions())
m_stream << cond.name;
// Structs
m_stream << UInt32(shader.GetStructCount());
for (const auto& s : shader.GetStructs())
{
m_stream << s.name;
m_stream << UInt32(s.members.size());
for (const auto& member : s.members)
{
m_stream << member.name;
SerializeType(member.type);
}
}
// Inputs / Outputs
SerializeInputOutput(shader.GetInputs());
SerializeInputOutput(shader.GetOutputs());
// Uniforms
m_stream << UInt32(shader.GetUniformCount());
for (const auto& uniform : shader.GetUniforms())
{
m_stream << uniform.name;
SerializeType(uniform.type);
m_stream << uniform.bindingIndex.has_value();
if (uniform.bindingIndex)
m_stream << UInt32(uniform.bindingIndex.value());
m_stream << uniform.memoryLayout.has_value();
if (uniform.memoryLayout)
m_stream << UInt32(uniform.memoryLayout.value());
}
// Functions
m_stream << UInt32(shader.GetFunctionCount());
for (const auto& func : shader.GetFunctions())
{
m_stream << func.name;
SerializeType(func.returnType);
m_stream << UInt32(func.parameters.size());
for (const auto& param : func.parameters)
{
m_stream << param.name;
SerializeType(param.type);
}
Node(func.statement);
}
Node(shader);
m_stream.FlushBits();
}
@@ -422,9 +222,21 @@ namespace Nz
return true;
}
void ShaderAstSerializer::Node(ShaderNodes::NodePtr& node)
void ShaderAstSerializer::Node(ExpressionPtr& node)
{
ShaderNodes::NodeType nodeType = (node) ? node->GetType() : ShaderNodes::NodeType::None;
NodeType nodeType = (node) ? node->GetType() : NodeType::None;
m_stream << static_cast<Int32>(nodeType);
if (node)
{
ShaderSerializerVisitor visitor(*this);
node->Visit(visitor);
}
}
void ShaderAstSerializer::Node(StatementPtr& node)
{
NodeType nodeType = (node) ? node->GetType() : NodeType::None;
m_stream << static_cast<Int32>(nodeType);
if (node)
@@ -439,7 +251,7 @@ namespace Nz
std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
if constexpr (std::is_same_v<T, BasicType>)
{
m_stream << UInt8(0);
m_stream << UInt32(arg);
@@ -454,11 +266,6 @@ namespace Nz
}, type);
}
void ShaderAstSerializer::Node(const ShaderNodes::NodePtr& node)
{
Node(const_cast<ShaderNodes::NodePtr&>(node)); //< Yes const_cast is ugly but it won't be used for writing
}
void ShaderAstSerializer::Value(bool& val)
{
m_stream << val;
@@ -529,19 +336,7 @@ namespace Nz
m_stream << val;
}
void ShaderAstSerializer::Variable(ShaderNodes::VariablePtr& var)
{
ShaderNodes::VariableType nodeType = (var) ? var->GetType() : ShaderNodes::VariableType::None;
m_stream << static_cast<Int32>(nodeType);
if (var)
{
ShaderSerializerVisitor visitor(*this);
var->Visit(visitor);
}
}
ShaderAst ShaderAstUnserializer::Unserialize()
StatementPtr ShaderAstUnserializer::Unserialize()
{
UInt32 magicNumber;
UInt32 version;
@@ -553,122 +348,13 @@ namespace Nz
if (version > s_currentVersion)
throw std::runtime_error("unsupported version");
UInt32 shaderStage;
m_stream >> shaderStage;
StatementPtr node;
ShaderAst shader(static_cast<ShaderStageType>(shaderStage));
Node(node);
if (!node)
throw std::runtime_error("functions can only have statements");
// Conditions
UInt32 conditionCount;
m_stream >> conditionCount;
for (UInt32 i = 0; i < conditionCount; ++i)
{
std::string conditionName;
Value(conditionName);
shader.AddCondition(std::move(conditionName));
}
// Structs
UInt32 structCount;
m_stream >> structCount;
for (UInt32 i = 0; i < structCount; ++i)
{
std::string structName;
std::vector<ShaderAst::StructMember> members;
Value(structName);
Container(members);
for (auto& member : members)
{
Value(member.name);
Type(member.type);
}
shader.AddStruct(std::move(structName), std::move(members));
}
// Inputs
UInt32 inputCount;
m_stream >> inputCount;
for (UInt32 i = 0; i < inputCount; ++i)
{
std::string inputName;
ShaderExpressionType inputType;
std::optional<std::size_t> location;
Value(inputName);
Type(inputType);
OptVal(location);
shader.AddInput(std::move(inputName), std::move(inputType), location);
}
// Outputs
UInt32 outputCount;
m_stream >> outputCount;
for (UInt32 i = 0; i < outputCount; ++i)
{
std::string outputName;
ShaderExpressionType outputType;
std::optional<std::size_t> location;
Value(outputName);
Type(outputType);
OptVal(location);
shader.AddOutput(std::move(outputName), std::move(outputType), location);
}
// Uniforms
UInt32 uniformCount;
m_stream >> uniformCount;
for (UInt32 i = 0; i < uniformCount; ++i)
{
std::string name;
ShaderExpressionType type;
std::optional<std::size_t> binding;
std::optional<ShaderNodes::MemoryLayout> memLayout;
Value(name);
Type(type);
OptVal(binding);
OptEnum(memLayout);
shader.AddUniform(std::move(name), std::move(type), std::move(binding), std::move(memLayout));
}
// Functions
UInt32 funcCount;
m_stream >> funcCount;
for (UInt32 i = 0; i < funcCount; ++i)
{
std::string name;
ShaderExpressionType retType;
std::vector<ShaderAst::FunctionParameter> parameters;
Value(name);
Type(retType);
Container(parameters);
for (auto& param : parameters)
{
Value(param.name);
Type(param.type);
}
ShaderNodes::NodePtr node;
Node(node);
if (!node || !node->IsStatement())
throw std::runtime_error("functions can only have statements");
ShaderNodes::StatementPtr statement = std::static_pointer_cast<ShaderNodes::Statement>(node);
shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), std::move(retType));
}
return shader;
return node;
}
bool ShaderAstUnserializer::IsWriting() const
@@ -676,41 +362,50 @@ namespace Nz
return false;
}
void ShaderAstUnserializer::Node(ShaderNodes::NodePtr& node)
void ShaderAstUnserializer::Node(ExpressionPtr& node)
{
Int32 nodeTypeInt;
m_stream >> nodeTypeInt;
if (nodeTypeInt < static_cast<Int32>(ShaderNodes::NodeType::None) || nodeTypeInt > static_cast<Int32>(ShaderNodes::NodeType::Max))
if (nodeTypeInt < static_cast<Int32>(NodeType::None) || nodeTypeInt > static_cast<Int32>(NodeType::Max))
throw std::runtime_error("invalid node type");
ShaderNodes::NodeType nodeType = static_cast<ShaderNodes::NodeType>(nodeTypeInt);
#define HandleType(Type) case ShaderNodes::NodeType:: Type : node = std::make_shared<ShaderNodes:: Type>(); break
NodeType nodeType = static_cast<NodeType>(nodeTypeInt);
switch (nodeType)
{
case ShaderNodes::NodeType::None: break;
case NodeType::None: break;
HandleType(AccessMember);
HandleType(AssignOp);
HandleType(BinaryOp);
HandleType(Branch);
HandleType(Cast);
HandleType(Constant);
HandleType(ConditionalExpression);
HandleType(ConditionalStatement);
HandleType(DeclareVariable);
HandleType(Discard);
HandleType(ExpressionStatement);
HandleType(Identifier);
HandleType(IntrinsicCall);
HandleType(NoOp);
HandleType(ReturnStatement);
HandleType(Sample2D);
HandleType(SwizzleOp);
HandleType(StatementBlock);
#define NAZARA_SHADERAST_EXPRESSION(Node) case NodeType:: Node : node = std::make_unique<Node>(); break;
#include <Nazara/Shader/ShaderAstNodes.hpp>
default: throw std::runtime_error("unexpected node type");
}
if (node)
{
ShaderSerializerVisitor visitor(*this);
node->Visit(visitor);
}
}
void ShaderAstUnserializer::Node(StatementPtr& node)
{
Int32 nodeTypeInt;
m_stream >> nodeTypeInt;
if (nodeTypeInt < static_cast<Int32>(NodeType::None) || nodeTypeInt > static_cast<Int32>(NodeType::Max))
throw std::runtime_error("invalid node type");
NodeType nodeType = static_cast<NodeType>(nodeTypeInt);
switch (nodeType)
{
case NodeType::None: break;
#define NAZARA_SHADERAST_STATEMENT(Node) case NodeType:: Node : node = std::make_unique<Node>(); break;
#include <Nazara/Shader/ShaderAstNodes.hpp>
default: throw std::runtime_error("unexpected node type");
}
#undef HandleType
if (node)
{
@@ -728,7 +423,7 @@ namespace Nz
{
case 0: //< Primitive
{
ShaderNodes::BasicType exprType;
BasicType exprType;
Enum(exprType);
type = exprType;
@@ -819,36 +514,8 @@ namespace Nz
m_stream >> val;
}
void ShaderAstUnserializer::Variable(ShaderNodes::VariablePtr& var)
{
Int32 nodeTypeInt;
m_stream >> nodeTypeInt;
ShaderNodes::VariableType nodeType = static_cast<ShaderNodes:: VariableType>(nodeTypeInt);
#define HandleType(Type) case ShaderNodes::VariableType:: Type : var = std::make_shared<ShaderNodes::Type>(); break
switch (nodeType)
{
case ShaderNodes::VariableType::None: break;
HandleType(BuiltinVariable);
HandleType(InputVariable);
HandleType(LocalVariable);
HandleType(ParameterVariable);
HandleType(OutputVariable);
HandleType(UniformVariable);
}
#undef HandleType
if (var)
{
ShaderSerializerVisitor visitor(*this);
var->Visit(visitor);
}
}
ByteArray SerializeShader(const ShaderAst& shader)
ByteArray SerializeShader(StatementPtr& shader)
{
ByteArray byteArray;
ByteStream stream(&byteArray, OpenModeFlags(OpenMode_WriteOnly));
@@ -859,7 +526,7 @@ namespace Nz
return byteArray;
}
ShaderAst UnserializeShader(ByteStream& stream)
StatementPtr UnserializeShader(ByteStream& stream)
{
ShaderAstUnserializer unserializer(stream);
return unserializer.Unserialize();

View File

@@ -2,15 +2,10 @@
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderVarVisitor::~ShaderVarVisitor() = default;
void ShaderVarVisitor::Visit(const ShaderNodes::VariablePtr& node)
{
node->Visit(*this);
}
AstStatementVisitor::~AstStatementVisitor() = default;
}

View File

@@ -0,0 +1,15 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstStatementVisitorExcept.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
{
#define NAZARA_SHADERAST_STATEMENT(Node) void StatementVisitorExcept::Visit(ShaderAst::Node& /*node*/) \
{ \
throw std::runtime_error("unexpected " #Node " node"); \
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
}

View File

@@ -5,69 +5,65 @@
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
ShaderNodes::ExpressionCategory ShaderAstValueCategory::GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression)
ExpressionCategory ShaderAstValueCategory::GetExpressionCategory(Expression& expression)
{
Visit(expression);
expression.Visit(*this);
return m_expressionCategory;
}
void ShaderAstValueCategory::Visit(ShaderNodes::AccessMember& node)
void ShaderAstValueCategory::Visit(AccessMemberExpression& node)
{
Visit(node.structExpr);
node.structExpr->Visit(*this);
}
void ShaderAstValueCategory::Visit(ShaderNodes::AssignOp& node)
void ShaderAstValueCategory::Visit(AssignExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::BinaryOp& node)
void ShaderAstValueCategory::Visit(BinaryExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::Cast& node)
void ShaderAstValueCategory::Visit(CastExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::ConditionalExpression& node)
void ShaderAstValueCategory::Visit(ConditionalExpression& node)
{
Visit(node.truePath);
ShaderNodes::ExpressionCategory trueExprCategory = m_expressionCategory;
Visit(node.falsePath);
ShaderNodes::ExpressionCategory falseExprCategory = m_expressionCategory;
node.truePath->Visit(*this);
ExpressionCategory trueExprCategory = m_expressionCategory;
if (trueExprCategory == ShaderNodes::ExpressionCategory::RValue || falseExprCategory == ShaderNodes::ExpressionCategory::RValue)
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
node.falsePath->Visit(*this);
ExpressionCategory falseExprCategory = m_expressionCategory;
if (trueExprCategory == ExpressionCategory::RValue || falseExprCategory == ExpressionCategory::RValue)
m_expressionCategory = ExpressionCategory::RValue;
else
m_expressionCategory = ShaderNodes::ExpressionCategory::LValue;
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::Constant& node)
void ShaderAstValueCategory::Visit(ConstantExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::Identifier& node)
void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::LValue;
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::IntrinsicCall& node)
void ShaderAstValueCategory::Visit(IntrinsicExpression& /*node*/)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
m_expressionCategory = ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::Sample2D& node)
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
{
m_expressionCategory = ShaderNodes::ExpressionCategory::RValue;
}
void ShaderAstValueCategory::Visit(ShaderNodes::SwizzleOp& node)
{
Visit(node.expression);
node.expression->Visit(*this);
}
}

View File

@@ -4,48 +4,40 @@
#include <Nazara/Shader/ShaderAstValidator.hpp>
#include <Nazara/Core/CallOnExit.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/ShaderAstUtils.hpp>
#include <Nazara/Shader/ShaderVariables.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <vector>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
namespace Nz::ShaderAst
{
struct AstError
{
std::string errMsg;
};
struct ShaderAstValidator::Context
struct AstValidator::Context
{
struct Local
{
std::string name;
ShaderExpressionType type;
};
const ShaderAst::Function* currentFunction;
std::vector<Local> declaredLocals;
std::vector<std::size_t> blockLocalIndex;
//const ShaderAst::Function* currentFunction;
std::optional<std::size_t> activeScopeId;
AstCache* cache;
};
bool ShaderAstValidator::Validate(std::string* error)
bool AstValidator::Validate(StatementPtr& node, std::string* error, AstCache* cache)
{
try
{
for (std::size_t i = 0; i < m_shader.GetFunctionCount(); ++i)
{
const auto& func = m_shader.GetFunction(i);
AstCache dummy;
Context currentContext;
currentContext.currentFunction = &func;
Context currentContext;
currentContext.cache = (cache) ? cache : &dummy;
m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; });
m_context = &currentContext;
CallOnExit resetContext([&] { m_context = nullptr; });
func.statement->Visit(*this);
}
EnterScope();
node->Visit(*this);
ExitScope();
return true;
}
@@ -58,148 +50,183 @@ namespace Nz
}
}
const ShaderNodes::ExpressionPtr& ShaderAstValidator::MandatoryExpr(const ShaderNodes::ExpressionPtr& node)
{
MandatoryNode(node);
return node;
}
const ShaderNodes::NodePtr& ShaderAstValidator::MandatoryNode(const ShaderNodes::NodePtr& node)
Expression& AstValidator::MandatoryExpr(ExpressionPtr& node)
{
if (!node)
throw AstError{ "Invalid node" };
throw AstError{ "Invalid expression" };
return node;
return *node;
}
void ShaderAstValidator::TypeMustMatch(const ShaderNodes::ExpressionPtr& left, const ShaderNodes::ExpressionPtr& right)
Statement& AstValidator::MandatoryStatement(StatementPtr& node)
{
return TypeMustMatch(left->GetExpressionType(), right->GetExpressionType());
if (!node)
throw AstError{ "Invalid statement" };
return *node;
}
void ShaderAstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right)
void AstValidator::TypeMustMatch(ExpressionPtr& left, ExpressionPtr& right)
{
return TypeMustMatch(GetExpressionType(*left, m_context->cache), GetExpressionType(*right, m_context->cache));
}
void AstValidator::TypeMustMatch(const ShaderExpressionType& left, const ShaderExpressionType& right)
{
if (left != right)
throw AstError{ "Left expression type must match right expression type" };
}
const ShaderAst::StructMember& ShaderAstValidator::CheckField(const std::string& structName, std::size_t* memberIndex, std::size_t remainingMembers)
ShaderExpressionType AstValidator::CheckField(const std::string& structName, const std::string* memberIdentifier, std::size_t remainingMembers)
{
const auto& structs = m_shader.GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
if (it == structs.end())
throw AstError{ "invalid structure" };
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, structName);
if (!identifier)
throw AstError{ "unknown identifier " + structName };
const ShaderAst::Struct& s = *it;
if (*memberIndex >= s.members.size())
throw AstError{ "member index out of bounds" };
if (std::holds_alternative<StructDescription>(identifier->value))
throw AstError{ "identifier is not a struct" };
const auto& member = s.members[*memberIndex];
const StructDescription& s = std::get<StructDescription>(identifier->value);
auto memberIt = std::find_if(s.members.begin(), s.members.begin(), [&](const auto& field) { return field.name == memberIdentifier[0]; });
if (memberIt == s.members.end())
throw AstError{ "unknown field " + memberIdentifier[0]};
const auto& member = *memberIt;
if (remainingMembers > 1)
{
if (!IsStructType(member.type))
throw AstError{ "member type does not match node type" };
return CheckField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
}
return CheckField(std::get<std::string>(member.type), memberIdentifier + 1, remainingMembers - 1);
else
return member;
return member.type;
}
void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node)
AstCache::Scope& AstValidator::EnterScope()
{
const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType();
std::size_t newScopeId = m_context->cache->scopes.size();
std::optional<std::size_t> previousScope = m_context->activeScopeId;
auto& newScope = m_context->cache->scopes.emplace_back();
newScope.parentScopeIndex = previousScope;
m_context->activeScopeId = newScopeId;
return m_context->cache->scopes[newScopeId];
}
void AstValidator::ExitScope()
{
assert(m_context->activeScopeId);
auto& previousScope = m_context->cache->scopes[*m_context->activeScopeId];
m_context->activeScopeId = previousScope.parentScopeIndex;
}
void AstValidator::RegisterExpressionType(Expression& node, ShaderExpressionType expressionType)
{
m_context->cache->nodeExpressionType[&node] = std::move(expressionType);
}
void AstValidator::RegisterScope(Node& node)
{
if (m_context->activeScopeId)
m_context->cache->scopeIdByNode[&node] = *m_context->activeScopeId;
}
void AstValidator::Visit(AccessMemberExpression& node)
{
RegisterScope(node);
const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.structExpr), m_context->cache);
if (!IsStructType(exprType))
throw AstError{ "expression is not a structure" };
const std::string& structName = std::get<std::string>(exprType);
const auto& member = CheckField(structName, node.memberIndices.data(), node.memberIndices.size());
if (member.type != node.exprType)
throw AstError{ "member type does not match node type" };
RegisterExpressionType(node, CheckField(structName, node.memberIdentifiers.data(), node.memberIdentifiers.size()));
}
void ShaderAstValidator::Visit(ShaderNodes::AssignOp& node)
void AstValidator::Visit(AssignExpression& node)
{
MandatoryNode(node.left);
MandatoryNode(node.right);
RegisterScope(node);
MandatoryExpr(node.left);
MandatoryExpr(node.right);
TypeMustMatch(node.left, node.right);
if (GetExpressionCategory(node.left) != ShaderNodes::ExpressionCategory::LValue)
if (GetExpressionCategory(*node.left) != ExpressionCategory::LValue)
throw AstError { "Assignation is only possible with a l-value" };
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::BinaryOp& node)
void AstValidator::Visit(BinaryExpression& node)
{
MandatoryNode(node.left);
MandatoryNode(node.right);
RegisterScope(node);
const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType();
// Register expression type
AstRecursiveVisitor::Visit(node);
const ShaderExpressionType& leftExprType = GetExpressionType(MandatoryExpr(node.left), m_context->cache);
if (!IsBasicType(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType();
const ShaderExpressionType& rightExprType = GetExpressionType(MandatoryExpr(node.right), m_context->cache);
if (!IsBasicType(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
BasicType leftType = std::get<BasicType>(leftExprType);
BasicType rightType = std::get<BasicType>(rightExprType);
switch (node.op)
{
case ShaderNodes::BinaryType::CompGe:
case ShaderNodes::BinaryType::CompGt:
case ShaderNodes::BinaryType::CompLe:
case ShaderNodes::BinaryType::CompLt:
if (leftType == ShaderNodes::BasicType::Boolean)
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
if (leftType == BasicType::Boolean)
throw AstError{ "this operation is not supported for booleans" };
[[fallthrough]];
case ShaderNodes::BinaryType::Add:
case ShaderNodes::BinaryType::CompEq:
case ShaderNodes::BinaryType::CompNe:
case ShaderNodes::BinaryType::Subtract:
case BinaryType::Add:
case BinaryType::CompEq:
case BinaryType::CompNe:
case BinaryType::Subtract:
TypeMustMatch(node.left, node.right);
break;
case ShaderNodes::BinaryType::Multiply:
case ShaderNodes::BinaryType::Divide:
case BinaryType::Multiply:
case BinaryType::Divide:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Int1:
case BasicType::Float1:
case BasicType::Int1:
{
if (ShaderNodes::Node::GetComponentType(rightType) != leftType)
if (GetComponentType(rightType) != leftType)
throw AstError{ "Left expression type is not compatible with right expression type" };
break;
}
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
{
if (leftType != rightType && rightType != ShaderNodes::Node::GetComponentType(leftType))
if (leftType != rightType && rightType != GetComponentType(leftType))
throw AstError{ "Left expression type is not compatible with right expression type" };
break;
}
case ShaderNodes::BasicType::Mat4x4:
case BasicType::Mat4x4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case BasicType::Float1:
case BasicType::Float4:
case BasicType::Mat4x4:
break;
default:
@@ -211,120 +238,86 @@ namespace Nz
default:
TypeMustMatch(node.left, node.right);
break;
}
}
}
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Branch& node)
void AstValidator::Visit(CastExpression& node)
{
for (const auto& condStatement : node.condStatements)
{
const ShaderExpressionType& condType = MandatoryExpr(condStatement.condition)->GetExpressionType();
if (!IsBasicType(condType) || std::get<ShaderNodes::BasicType>(condType) != ShaderNodes::BasicType::Boolean)
throw AstError{ "if expression must resolve to boolean type" };
RegisterScope(node);
MandatoryNode(condStatement.statement);
}
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Cast& node)
{
unsigned int componentCount = 0;
unsigned int requiredComponents = node.GetComponentCount(node.exprType);
for (const auto& exprPtr : node.expressions)
unsigned int requiredComponents = GetComponentCount(node.targetType);
for (auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
const ShaderExpressionType& exprType = exprPtr->GetExpressionType();
ShaderExpressionType exprType = GetExpressionType(*exprPtr, m_context->cache);
if (!IsBasicType(exprType))
throw AstError{ "incompatible type" };
componentCount += node.GetComponentCount(std::get<ShaderNodes::BasicType>(exprType));
componentCount += GetComponentCount(std::get<BasicType>(exprType));
}
if (componentCount != requiredComponents)
throw AstError{ "component count doesn't match required component count" };
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::ConditionalExpression& node)
void AstValidator::Visit(ConditionalExpression& node)
{
MandatoryNode(node.truePath);
MandatoryNode(node.falsePath);
MandatoryExpr(node.truePath);
MandatoryExpr(node.falsePath);
if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
throw AstError{ "condition not found" };
RegisterScope(node);
AstRecursiveVisitor::Visit(node);
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
// throw AstError{ "condition not found" };
}
void ShaderAstValidator::Visit(ShaderNodes::ConditionalStatement& node)
void AstValidator::Visit(ConstantExpression& node)
{
MandatoryNode(node.statement);
if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
throw AstError{ "condition not found" };
RegisterScope(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Constant& /*node*/)
{
}
void ShaderAstValidator::Visit(ShaderNodes::DeclareVariable& node)
void AstValidator::Visit(IdentifierExpression& node)
{
assert(m_context);
if (node.variable->GetType() != ShaderNodes::VariableType::LocalVariable)
throw AstError{ "Only local variables can be declared in a statement" };
if (!m_context->activeScopeId)
throw AstError{ "no scope" };
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
RegisterScope(node);
auto& local = m_context->declaredLocals.emplace_back();
local.name = localVar.name;
local.type = localVar.type;
ShaderAstRecursiveVisitor::Visit(node);
const AstCache::Identifier* identifier = m_context->cache->FindIdentifier(*m_context->activeScopeId, node.identifier);
if (!identifier)
throw AstError{ "Unknown variable " + node.identifier };
}
void ShaderAstValidator::Visit(ShaderNodes::ExpressionStatement& node)
void AstValidator::Visit(IntrinsicExpression& node)
{
MandatoryNode(node.expression);
RegisterScope(node);
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Identifier& node)
{
assert(m_context);
if (!node.var)
throw AstError{ "Invalid variable" };
Visit(node.var);
}
void ShaderAstValidator::Visit(ShaderNodes::IntrinsicCall& node)
{
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::CrossProduct:
case ShaderNodes::IntrinsicType::DotProduct:
case IntrinsicType::CrossProduct:
case IntrinsicType::DotProduct:
{
if (node.parameters.size() != 2)
throw AstError { "Expected 2 parameters" };
for (auto& param : node.parameters)
MandatoryNode(param);
MandatoryExpr(param);
ShaderExpressionType type = node.parameters.front()->GetExpressionType();
ShaderExpressionType type = GetExpressionType(*node.parameters.front(), m_context->cache);
for (std::size_t i = 1; i < node.parameters.size(); ++i)
{
if (type != node.parameters[i]->GetExpressionType())
if (type != GetExpressionType(MandatoryExpr(node.parameters[i])), m_context->cache)
throw AstError{ "All type must match" };
}
@@ -334,180 +327,176 @@ namespace Nz
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::CrossProduct:
case IntrinsicType::CrossProduct:
{
if (node.parameters[0]->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float3 })
if (GetExpressionType(*node.parameters[0]) != ShaderExpressionType{ BasicType::Float3 }, m_context->cache)
throw AstError{ "CrossProduct only works with Float3 expressions" };
break;
}
case ShaderNodes::IntrinsicType::DotProduct:
case IntrinsicType::DotProduct:
break;
}
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::ReturnStatement& node)
void AstValidator::Visit(SwizzleExpression& node)
{
if (m_context->currentFunction->returnType != ShaderExpressionType(ShaderNodes::BasicType::Void))
{
if (MandatoryExpr(node.returnExpr)->GetExpressionType() != m_context->currentFunction->returnType)
throw AstError{ "Return type doesn't match function return type" };
}
else
{
if (node.returnExpr)
throw AstError{ "Unexpected expression for return (function doesn't return)" };
}
RegisterScope(node);
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Sample2D& node)
{
if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D })
throw AstError{ "Sampler must be a Sampler2D" };
if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float2 })
throw AstError{ "Coordinates must be a Float2" };
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::StatementBlock& node)
{
assert(m_context);
m_context->blockLocalIndex.push_back(m_context->declaredLocals.size());
for (const auto& statement : node.statements)
MandatoryNode(statement);
assert(m_context->declaredLocals.size() >= m_context->blockLocalIndex.back());
m_context->declaredLocals.resize(m_context->blockLocalIndex.back());
m_context->blockLocalIndex.pop_back();
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::SwizzleOp& node)
{
if (node.componentCount > 4)
throw AstError{ "Cannot swizzle more than four elements" };
const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType();
const ShaderExpressionType& exprType = GetExpressionType(MandatoryExpr(node.expression), m_context->cache);
if (!IsBasicType(exprType))
throw AstError{ "Cannot swizzle this type" };
switch (std::get<ShaderNodes::BasicType>(exprType))
switch (std::get<BasicType>(exprType))
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case BasicType::Float1:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int1:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
break;
default:
throw AstError{ "Cannot swizzle this type" };
}
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::BuiltinVariable& var)
void AstValidator::Visit(BranchStatement& node)
{
switch (var.entry)
RegisterScope(node);
for (auto& condStatement : node.condStatements)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
if (!IsBasicType(var.type) ||
std::get<ShaderNodes::BasicType>(var.type) != ShaderNodes::BasicType::Float4)
throw AstError{ "Builtin is not of the expected type" };
const ShaderExpressionType& condType = GetExpressionType(MandatoryExpr(condStatement.condition), m_context->cache);
if (!IsBasicType(condType) || std::get<BasicType>(condType) != BasicType::Boolean)
throw AstError{ "if expression must resolve to boolean type" };
break;
default:
break;
}
}
void ShaderAstValidator::Visit(ShaderNodes::InputVariable& var)
{
for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i)
{
const auto& input = m_shader.GetInput(i);
if (input.name == var.name)
{
TypeMustMatch(input.type, var.type);
return;
}
MandatoryStatement(condStatement.statement);
}
throw AstError{ "Input not found" };
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::LocalVariable& var)
void AstValidator::Visit(ConditionalStatement& node)
{
const auto& vars = m_context->declaredLocals;
MandatoryStatement(node.statement);
auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& v) { return v.name == var.name; });
if (it == vars.end())
throw AstError{ "Local variable not found in this block" };
RegisterScope(node);
TypeMustMatch(it->type, var.type);
AstRecursiveVisitor::Visit(node);
//if (m_shader.FindConditionByName(node.conditionName) == ShaderAst::InvalidCondition)
// throw AstError{ "condition not found" };
}
void ShaderAstValidator::Visit(ShaderNodes::OutputVariable& var)
void AstValidator::Visit(DeclareFunctionStatement& node)
{
for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i)
auto& scope = EnterScope();
RegisterScope(node);
for (auto& parameter : node.parameters)
{
const auto& input = m_shader.GetOutput(i);
if (input.name == var.name)
{
TypeMustMatch(input.type, var.type);
return;
}
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ parameter.name, AstCache::Variable { parameter.type } };
}
throw AstError{ "Output not found" };
for (auto& statement : node.statements)
MandatoryStatement(statement).Visit(*this);
ExitScope();
}
void ShaderAstValidator::Visit(ShaderNodes::ParameterVariable& var)
void AstValidator::Visit(DeclareStructStatement& node)
{
assert(m_context->currentFunction);
assert(m_context);
const auto& parameters = m_context->currentFunction->parameters;
if (!m_context->activeScopeId)
throw AstError{ "cannot declare variable without scope" };
auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == var.name; });
if (it == parameters.end())
throw AstError{ "Parameter not found in function" };
RegisterScope(node);
TypeMustMatch(it->type, var.type);
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ node.description.name, node.description };
AstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::UniformVariable& var)
void AstValidator::Visit(DeclareVariableStatement& node)
{
for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i)
assert(m_context);
if (!m_context->activeScopeId)
throw AstError{ "cannot declare variable without scope" };
RegisterScope(node);
auto& scope = m_context->cache->scopes[*m_context->activeScopeId];
auto& identifier = scope.identifiers.emplace_back();
identifier = AstCache::Identifier{ node.varName, AstCache::Variable { node.varType } };
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(ExpressionStatement& node)
{
RegisterScope(node);
MandatoryExpr(node.expression);
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(MultiStatement& node)
{
assert(m_context);
EnterScope();
RegisterScope(node);
for (auto& statement : node.statements)
MandatoryStatement(statement);
ExitScope();
AstRecursiveVisitor::Visit(node);
}
void AstValidator::Visit(ReturnStatement& node)
{
RegisterScope(node);
/*if (m_context->currentFunction->returnType != ShaderExpressionType(BasicType::Void))
{
const auto& uniform = m_shader.GetUniform(i);
if (uniform.name == var.name)
{
TypeMustMatch(uniform.type, var.type);
return;
}
if (GetExpressionType(MandatoryExpr(node.returnExpr)) != m_context->currentFunction->returnType)
throw AstError{ "Return type doesn't match function return type" };
}
else
{
if (node.returnExpr)
throw AstError{ "Unexpected expression for return (function doesn't return)" };
}*/
throw AstError{ "Uniform not found" };
AstRecursiveVisitor::Visit(node);
}
bool ValidateShader(const ShaderAst& shader, std::string* error)
bool ValidateAst(StatementPtr& node, std::string* error, AstCache* cache)
{
ShaderAstValidator validator(shader);
return validator.Validate(error);
AstValidator validator;
return validator.Validate(node, error, cache);
}
}

View File

@@ -1,100 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderAstVisitorExcept.hpp>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
void ShaderAstVisitorExcept::Visit(ShaderNodes::AccessMember& /*node*/)
{
throw std::runtime_error("unhandled AccessMember node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::AssignOp& /*node*/)
{
throw std::runtime_error("unhandled AssignOp node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::BinaryOp& /*node*/)
{
throw std::runtime_error("unhandled AccessMember node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Branch& /*node*/)
{
throw std::runtime_error("unhandled Branch node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Cast& /*node*/)
{
throw std::runtime_error("unhandled Cast node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalExpression& /*node*/)
{
throw std::runtime_error("unhandled ConditionalExpression node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ConditionalStatement& /*node*/)
{
throw std::runtime_error("unhandled ConditionalStatement node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Constant& /*node*/)
{
throw std::runtime_error("unhandled Constant node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::DeclareVariable& /*node*/)
{
throw std::runtime_error("unhandled DeclareVariable node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Discard& /*node*/)
{
throw std::runtime_error("unhandled Discard node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ExpressionStatement& /*node*/)
{
throw std::runtime_error("unhandled ExpressionStatement node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Identifier& /*node*/)
{
throw std::runtime_error("unhandled Identifier node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::IntrinsicCall& /*node*/)
{
throw std::runtime_error("unhandled IntrinsicCall node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::NoOp& node)
{
throw std::runtime_error("unhandled NoOp node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ReturnStatement& node)
{
throw std::runtime_error("unhandled ReturnStatement node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Sample2D& /*node*/)
{
throw std::runtime_error("unhandled Sample2D node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::StatementBlock& /*node*/)
{
throw std::runtime_error("unhandled StatementBlock node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::SwizzleOp& /*node*/)
{
throw std::runtime_error("unhandled SwizzleOp node");
}
}

View File

@@ -42,6 +42,7 @@ namespace Nz::ShaderLang
std::unordered_map<std::string, TokenType> reservedKeywords = {
{ "false", TokenType::BoolFalse },
{ "fn", TokenType::FunctionDeclaration },
{ "let", TokenType::Let },
{ "return", TokenType::Return },
{ "true", TokenType::BoolTrue }
};
@@ -143,7 +144,7 @@ namespace Nz::ShaderLang
while (next != -1);
}
else
tokenType == TokenType::Divide;
tokenType = TokenType::Divide;
break;
}
@@ -191,9 +192,11 @@ namespace Nz::ShaderLang
std::string valueStr(str.substr(start, currentPos - start + 1));
const char* ptr = valueStr.c_str();
char* end;
double value = std::strtod(valueStr.c_str(), &end);
if (end != &str[currentPos + 1])
double value = std::strtod(ptr, &end);
if (end != &ptr[valueStr.size()])
throw BadNumber{};
token.data = value;
@@ -218,6 +221,7 @@ namespace Nz::ShaderLang
break;
}
case '=': tokenType = TokenType::Assign; break;
case '+': tokenType = TokenType::Plus; break;
case '*': tokenType = TokenType::Multiply; break;
case ':': tokenType = TokenType::Colon; break;

View File

@@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderLangParser.hpp>
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <cassert>
#include <Nazara/Shader/Debug.hpp>
@@ -10,36 +11,38 @@ namespace Nz::ShaderLang
{
namespace
{
std::unordered_map<std::string, ShaderNodes::BasicType> identifierToBasicType = {
{ "bool", ShaderNodes::BasicType::Boolean },
std::unordered_map<std::string, ShaderAst::BasicType> identifierToBasicType = {
{ "bool", ShaderAst::BasicType::Boolean },
{ "i32", ShaderNodes::BasicType::Int1 },
{ "vec2i32", ShaderNodes::BasicType::Int2 },
{ "vec3i32", ShaderNodes::BasicType::Int3 },
{ "vec4i32", ShaderNodes::BasicType::Int4 },
{ "i32", ShaderAst::BasicType::Int1 },
{ "vec2i32", ShaderAst::BasicType::Int2 },
{ "vec3i32", ShaderAst::BasicType::Int3 },
{ "vec4i32", ShaderAst::BasicType::Int4 },
{ "f32", ShaderNodes::BasicType::Float1 },
{ "vec2f32", ShaderNodes::BasicType::Float2 },
{ "vec3f32", ShaderNodes::BasicType::Float3 },
{ "vec4f32", ShaderNodes::BasicType::Float4 },
{ "f32", ShaderAst::BasicType::Float1 },
{ "vec2f32", ShaderAst::BasicType::Float2 },
{ "vec3f32", ShaderAst::BasicType::Float3 },
{ "vec4f32", ShaderAst::BasicType::Float4 },
{ "mat4x4f32", ShaderNodes::BasicType::Mat4x4 },
{ "sampler2D", ShaderNodes::BasicType::Sampler2D },
{ "void", ShaderNodes::BasicType::Void },
{ "mat4x4f32", ShaderAst::BasicType::Mat4x4 },
{ "sampler2D", ShaderAst::BasicType::Sampler2D },
{ "void", ShaderAst::BasicType::Void },
{ "u32", ShaderNodes::BasicType::UInt1 },
{ "vec2u32", ShaderNodes::BasicType::UInt3 },
{ "vec3u32", ShaderNodes::BasicType::UInt3 },
{ "vec4u32", ShaderNodes::BasicType::UInt4 },
{ "u32", ShaderAst::BasicType::UInt1 },
{ "vec2u32", ShaderAst::BasicType::UInt3 },
{ "vec3u32", ShaderAst::BasicType::UInt3 },
{ "vec4u32", ShaderAst::BasicType::UInt4 },
};
}
ShaderAst Parser::Parse(const std::vector<Token>& tokens)
ShaderAst::StatementPtr Parser::Parse(const std::vector<Token>& tokens)
{
Context context;
context.tokenCount = tokens.size();
context.tokens = tokens.data();
context.root = std::make_unique<ShaderAst::MultiStatement>();
m_context = &context;
m_context->tokenIndex = -1;
@@ -51,7 +54,7 @@ namespace Nz::ShaderLang
switch (nextToken.type)
{
case TokenType::FunctionDeclaration:
ParseFunctionDeclaration();
context.root->statements.push_back(ParseFunctionDeclaration());
break;
case TokenType::EndOfStream:
@@ -63,7 +66,7 @@ namespace Nz::ShaderLang
}
}
return std::move(context.result);
return std::move(context.root);
}
const Token& Parser::Advance()
@@ -92,12 +95,12 @@ namespace Nz::ShaderLang
return m_context->tokens[m_context->tokenIndex + 1];
}
ShaderNodes::StatementPtr Parser::ParseFunctionBody()
std::vector<ShaderAst::StatementPtr> Parser::ParseFunctionBody()
{
return ParseStatementList();
}
void Parser::ParseFunctionDeclaration()
ShaderAst::StatementPtr Parser::ParseFunctionDeclaration()
{
ExpectNext(TokenType::FunctionDeclaration);
@@ -105,7 +108,7 @@ namespace Nz::ShaderLang
ExpectNext(TokenType::OpenParenthesis);
std::vector<ShaderAst::FunctionParameter> parameters;
std::vector<ShaderAst::DeclareFunctionStatement::Parameter> parameters;
bool firstParameter = true;
for (;;)
@@ -126,7 +129,7 @@ namespace Nz::ShaderLang
ExpectNext(TokenType::ClosingParenthesis);
ShaderExpressionType returnType = ShaderNodes::BasicType::Void;
ShaderAst::ShaderExpressionType returnType = ShaderAst::BasicType::Void;
if (PeekNext().type == TokenType::FunctionReturn)
{
Advance(); //< Consume ->
@@ -136,42 +139,46 @@ namespace Nz::ShaderLang
ExpectNext(TokenType::OpenCurlyBracket);
ShaderNodes::StatementPtr functionBody = ParseFunctionBody();
std::vector<ShaderAst::StatementPtr> functionBody = ParseFunctionBody();
ExpectNext(TokenType::ClosingCurlyBracket);
m_context->result.AddFunction(functionName, functionBody, std::move(parameters), returnType);
return ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType));
}
ShaderAst::FunctionParameter Parser::ParseFunctionParameter()
ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter()
{
std::string parameterName = ParseIdentifierAsName();
ExpectNext(TokenType::Colon);
ShaderExpressionType parameterType = ParseIdentifierAsType();
ShaderAst::ShaderExpressionType parameterType = ParseIdentifierAsType();
return { parameterName, parameterType };
}
ShaderNodes::StatementPtr Parser::ParseReturnStatement()
ShaderAst::StatementPtr Parser::ParseReturnStatement()
{
ExpectNext(TokenType::Return);
ShaderNodes::ExpressionPtr expr;
ShaderAst::ExpressionPtr expr;
if (PeekNext().type != TokenType::Semicolon)
expr = ParseExpression();
return ShaderNodes::ReturnStatement::Build(std::move(expr));
return ShaderBuilder::Return(std::move(expr));
}
ShaderNodes::StatementPtr Parser::ParseStatement()
ShaderAst::StatementPtr Parser::ParseStatement()
{
const Token& token = PeekNext();
ShaderNodes::StatementPtr statement;
ShaderAst::StatementPtr statement;
switch (token.type)
{
case TokenType::Let:
statement = ParseVariableDeclaration();
break;
case TokenType::Return:
statement = ParseReturnStatement();
break;
@@ -185,18 +192,38 @@ namespace Nz::ShaderLang
return statement;
}
ShaderNodes::StatementPtr Parser::ParseStatementList()
std::vector<ShaderAst::StatementPtr> Parser::ParseStatementList()
{
std::vector<ShaderNodes::StatementPtr> statements;
std::vector<ShaderAst::StatementPtr> statements;
while (PeekNext().type != TokenType::ClosingCurlyBracket)
{
statements.push_back(ParseStatement());
}
return ShaderNodes::StatementBlock::Build(std::move(statements));
return statements;
}
ShaderNodes::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs)
ShaderAst::StatementPtr Parser::ParseVariableDeclaration()
{
ExpectNext(TokenType::Let);
std::string variableName = ParseIdentifierAsName();
ExpectNext(TokenType::Colon);
ShaderAst::ShaderExpressionType variableType = ParseIdentifierAsType();
ShaderAst::ExpressionPtr expression;
if (PeekNext().type == TokenType::Assign)
{
Advance();
expression = ParseExpression();
}
return ShaderBuilder::DeclareVariable(std::move(variableName), std::move(variableType), std::move(expression));
}
ShaderAst::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs)
{
for (;;)
{
@@ -207,7 +234,7 @@ namespace Nz::ShaderLang
return lhs;
Advance();
ShaderNodes::ExpressionPtr rhs = ParsePrimaryExpression();
ShaderAst::ExpressionPtr rhs = ParsePrimaryExpression();
const Token& nextOp = PeekNext();
@@ -215,57 +242,58 @@ namespace Nz::ShaderLang
if (tokenPrecedence < nextTokenPrecedence)
rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs));
ShaderNodes::BinaryType binaryType;
ShaderAst::BinaryType binaryType;
{
switch (currentOp.type)
{
case TokenType::Plus: binaryType = ShaderNodes::BinaryType::Add; break;
case TokenType::Minus: binaryType = ShaderNodes::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderNodes::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderNodes::BinaryType::Divide; break;
case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break;
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break;
default: throw UnexpectedToken{};
}
}
lhs = ShaderNodes::BinaryOp::Build(binaryType, std::move(lhs), std::move(rhs));
lhs = ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs));
}
}
ShaderNodes::ExpressionPtr Parser::ParseExpression()
ShaderAst::ExpressionPtr Parser::ParseExpression()
{
return ParseBinOpRhs(0, ParsePrimaryExpression());
}
ShaderNodes::ExpressionPtr Parser::ParseIdentifier()
ShaderAst::ExpressionPtr Parser::ParseIdentifier()
{
const Token& identifier = ExpectNext(TokenType::Identifier);
return ShaderNodes::Identifier::Build(ShaderNodes::ParameterVariable::Build(std::get<std::string>(identifier.data), ShaderNodes::BasicType::Float3));
return ShaderBuilder::Identifier(std::get<std::string>(identifier.data));
}
ShaderNodes::ExpressionPtr Parser::ParseIntegerExpression()
ShaderAst::ExpressionPtr Parser::ParseIntegerExpression()
{
const Token& integer = ExpectNext(TokenType::IntegerValue);
return ShaderNodes::Constant::Build(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
return ShaderBuilder::Constant(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
}
ShaderNodes::ExpressionPtr Parser::ParseParenthesisExpression()
ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression()
{
ExpectNext(TokenType::OpenParenthesis);
ShaderNodes::ExpressionPtr expression = ParseExpression();
ShaderAst::ExpressionPtr expression = ParseExpression();
ExpectNext(TokenType::ClosingParenthesis);
return expression;
}
ShaderNodes::ExpressionPtr Parser::ParsePrimaryExpression()
ShaderAst::ExpressionPtr Parser::ParsePrimaryExpression()
{
const Token& token = PeekNext();
switch (token.type)
{
case TokenType::BoolFalse: return ShaderNodes::Constant::Build(false);
case TokenType::BoolTrue: return ShaderNodes::Constant::Build(true);
case TokenType::BoolFalse: return ShaderBuilder::Constant(false);
case TokenType::BoolTrue: return ShaderBuilder::Constant(true);
case TokenType::FloatingPointValue: return ShaderBuilder::Constant(float(std::get<double>(Advance().data))); //< FIXME
case TokenType::Identifier: return ParseIdentifier();
case TokenType::IntegerValue: return ParseIntegerExpression();
case TokenType::OpenParenthesis: return ParseParenthesisExpression();
@@ -286,7 +314,7 @@ namespace Nz::ShaderLang
return identifier;
}
ShaderExpressionType Parser::ParseIdentifierAsType()
ShaderAst::ShaderExpressionType Parser::ParseIdentifierAsType()
{
const Token& identifier = ExpectNext(TokenType::Identifier);

View File

@@ -4,265 +4,29 @@
#include <Nazara/Shader/ShaderNodes.hpp>
#include <Nazara/Core/Algorithm.hpp>
#include <Nazara/Shader/ShaderAstSerializer.hpp>
#include <Nazara/Shader/ShaderAstVisitor.hpp>
#include <Nazara/Shader/ShaderWriter.hpp>
#include <Nazara/Shader/ShaderAstExpressionVisitor.hpp>
#include <Nazara/Shader/ShaderAstStatementVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderNodes
namespace Nz::ShaderAst
{
Node::~Node() = default;
void ExpressionStatement::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
#define NAZARA_SHADERAST_NODE(Node) NodeType Node::GetType() const \
{ \
return NodeType:: Node; \
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
#define NAZARA_SHADERAST_EXPRESSION(Node) void Node::Visit(AstExpressionVisitor& visitor) \
{\
visitor.Visit(*this); \
}
void ConditionalStatement::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
#define NAZARA_SHADERAST_STATEMENT(Node) void Node::Visit(AstStatementVisitor& visitor) \
{\
visitor.Visit(*this); \
}
void StatementBlock::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void DeclareVariable::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void Discard::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType Identifier::GetExpressionType() const
{
assert(var);
return var->type;
}
void Identifier::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType AccessMember::GetExpressionType() const
{
return exprType;
}
void AccessMember::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void NoOp::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void ReturnStatement::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType AssignOp::GetExpressionType() const
{
return left->GetExpressionType();
}
void AssignOp::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType BinaryOp::GetExpressionType() const
{
std::optional<ShaderExpressionType> exprType;
switch (op)
{
case BinaryType::Add:
case BinaryType::Subtract:
exprType = left->GetExpressionType();
break;
case BinaryType::Divide:
case BinaryType::Multiply:
{
const ShaderExpressionType& leftExprType = left->GetExpressionType();
assert(IsBasicType(leftExprType));
const ShaderExpressionType& rightExprType = right->GetExpressionType();
assert(IsBasicType(rightExprType));
switch (std::get<BasicType>(leftExprType))
{
case BasicType::Boolean:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
exprType = leftExprType;
break;
case BasicType::Float1:
case BasicType::Int1:
case BasicType::Mat4x4:
case BasicType::UInt1:
exprType = rightExprType;
break;
case BasicType::Sampler2D:
case BasicType::Void:
break;
}
break;
}
case BinaryType::CompEq:
case BinaryType::CompGe:
case BinaryType::CompGt:
case BinaryType::CompLe:
case BinaryType::CompLt:
case BinaryType::CompNe:
exprType = BasicType::Boolean;
break;
}
NazaraAssert(exprType.has_value(), "Unhandled builtin");
return *exprType;
}
void BinaryOp::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
void Branch::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType Constant::GetExpressionType() const
{
return std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, bool>)
return ShaderNodes::BasicType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return ShaderNodes::BasicType::Float1;
else if constexpr (std::is_same_v<T, Int32>)
return ShaderNodes::BasicType::Int1;
else if constexpr (std::is_same_v<T, UInt32>)
return ShaderNodes::BasicType::Int1;
else if constexpr (std::is_same_v<T, Vector2f>)
return ShaderNodes::BasicType::Float2;
else if constexpr (std::is_same_v<T, Vector3f>)
return ShaderNodes::BasicType::Float3;
else if constexpr (std::is_same_v<T, Vector4f>)
return ShaderNodes::BasicType::Float4;
else if constexpr (std::is_same_v<T, Vector2i32>)
return ShaderNodes::BasicType::Int2;
else if constexpr (std::is_same_v<T, Vector3i32>)
return ShaderNodes::BasicType::Int3;
else if constexpr (std::is_same_v<T, Vector4i32>)
return ShaderNodes::BasicType::Int4;
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
}, value);
}
void Constant::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType Cast::GetExpressionType() const
{
return exprType;
}
void Cast::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType ConditionalExpression::GetExpressionType() const
{
assert(truePath->GetExpressionType() == falsePath->GetExpressionType());
return truePath->GetExpressionType();
}
void ConditionalExpression::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType SwizzleOp::GetExpressionType() const
{
const ShaderExpressionType& exprType = expression->GetExpressionType();
assert(IsBasicType(exprType));
return static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + componentCount - 1);
}
void SwizzleOp::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType Sample2D::GetExpressionType() const
{
return BasicType::Float4;
}
void Sample2D::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType IntrinsicCall::GetExpressionType() const
{
switch (intrinsic)
{
case IntrinsicType::CrossProduct:
return parameters.front()->GetExpressionType();
case IntrinsicType::DotProduct:
return BasicType::Float1;
}
NazaraAssert(false, "Unhandled builtin");
return BasicType::Void;
}
void IntrinsicCall::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
#include <Nazara/Shader/ShaderAstNodes.hpp>
}

View File

@@ -1,40 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderVarVisitorExcept.hpp>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
void ShaderVarVisitorExcept::Visit(ShaderNodes::BuiltinVariable& /*var*/)
{
throw std::runtime_error("unhandled BuiltinVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::InputVariable& /*var*/)
{
throw std::runtime_error("unhandled InputVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::LocalVariable& /*var*/)
{
throw std::runtime_error("unhandled LocalVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::OutputVariable& /*var*/)
{
throw std::runtime_error("unhandled OutputVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::ParameterVariable& /*var*/)
{
throw std::runtime_error("unhandled ParameterVariable");
}
void ShaderVarVisitorExcept::Visit(ShaderNodes::UniformVariable& /*var*/)
{
throw std::runtime_error("unhandled UniformVariable");
}
}

View File

@@ -1,77 +0,0 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderVariables.hpp>
#include <Nazara/Shader/ShaderVarVisitor.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderNodes
{
ShaderNodes::Variable::~Variable() = default;
VariableType BuiltinVariable::GetType() const
{
return VariableType::BuiltinVariable;
}
void BuiltinVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType InputVariable::GetType() const
{
return VariableType::InputVariable;
}
void InputVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType LocalVariable::GetType() const
{
return VariableType::LocalVariable;
}
void LocalVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType OutputVariable::GetType() const
{
return VariableType::OutputVariable;
}
void OutputVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType ParameterVariable::GetType() const
{
return VariableType::ParameterVariable;
}
void ParameterVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
VariableType UniformVariable::GetType() const
{
return VariableType::UniformVariable;
}
void UniformVariable::Visit(ShaderVarVisitor& visitor)
{
visitor.Visit(*this);
}
}

View File

@@ -4,6 +4,7 @@
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Core/StackVector.hpp>
#include <Nazara/Shader/ShaderAstExpressionType.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvExpressionLoad.hpp>
#include <Nazara/Shader/SpirvExpressionStore.hpp>
@@ -12,21 +13,21 @@
namespace Nz
{
UInt32 SpirvAstVisitor::EvaluateExpression(const ShaderNodes::ExpressionPtr& expr)
UInt32 SpirvAstVisitor::EvaluateExpression(ShaderAst::ExpressionPtr& expr)
{
Visit(expr);
expr->Visit(*this);
assert(m_resultIds.size() == 1);
return PopResultId();
}
void SpirvAstVisitor::Visit(ShaderNodes::AccessMember& node)
void SpirvAstVisitor::Visit(ShaderAst::AccessMemberExpression& node)
{
SpirvExpressionLoad accessMemberVisitor(m_writer, *m_currentBlock);
PushResultId(accessMemberVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderNodes::AssignOp& node)
void SpirvAstVisitor::Visit(ShaderAst::AssignExpression& node)
{
UInt32 resultId = EvaluateExpression(node.right);
@@ -36,20 +37,20 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::BinaryOp& node)
void SpirvAstVisitor::Visit(ShaderAst::BinaryExpression& node)
{
ShaderExpressionType resultExprType = node.GetExpressionType();
ShaderAst::ShaderExpressionType resultExprType = ShaderAst::GetExpressionType(node);
assert(IsBasicType(resultExprType));
const ShaderExpressionType& leftExprType = node.left->GetExpressionType();
ShaderAst::ShaderExpressionType leftExprType = ShaderAst::GetExpressionType(*node.left);
assert(IsBasicType(leftExprType));
const ShaderExpressionType& rightExprType = node.right->GetExpressionType();
ShaderAst::ShaderExpressionType rightExprType = ShaderAst::GetExpressionType(*node.right);
assert(IsBasicType(rightExprType));
ShaderNodes::BasicType resultType = std::get<ShaderNodes::BasicType>(resultExprType);
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
ShaderAst::BasicType resultType = std::get<ShaderAst::BasicType>(resultExprType);
ShaderAst::BasicType leftType = std::get<ShaderAst::BasicType>(leftExprType);
ShaderAst::BasicType rightType = std::get<ShaderAst::BasicType>(rightExprType);
UInt32 leftOperand = EvaluateExpression(node.left);
@@ -62,308 +63,308 @@ namespace Nz
{
switch (node.op)
{
case ShaderNodes::BinaryType::Add:
case ShaderAst::BinaryType::Add:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFAdd;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpIAdd;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::Subtract:
case ShaderAst::BinaryType::Subtract:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFSub;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpISub;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::Divide:
case ShaderAst::BinaryType::Divide:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFDiv;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSDiv;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpUDiv;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompEq:
case ShaderAst::BinaryType::CompEq:
{
switch (leftType)
{
case ShaderNodes::BasicType::Boolean:
case ShaderAst::BasicType::Boolean:
return SpirvOp::OpLogicalEqual;
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpIEqual;
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompGe:
case ShaderAst::BinaryType::CompGe:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdGreaterThan;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSGreaterThan;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpUGreaterThan;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompGt:
case ShaderAst::BinaryType::CompGt:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdGreaterThanEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSGreaterThanEqual;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpUGreaterThanEqual;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompLe:
case ShaderAst::BinaryType::CompLe:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdLessThanEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSLessThanEqual;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpULessThanEqual;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompLt:
case ShaderAst::BinaryType::CompLt:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdLessThan;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
return SpirvOp::OpSLessThan;
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpULessThan;
case ShaderNodes::BasicType::Boolean:
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Boolean:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::CompNe:
case ShaderAst::BinaryType::CompNe:
{
switch (leftType)
{
case ShaderNodes::BasicType::Boolean:
case ShaderAst::BasicType::Boolean:
return SpirvOp::OpLogicalNotEqual;
case ShaderNodes::BasicType::Float1:
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Float1:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpFOrdNotEqual;
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpINotEqual;
case ShaderNodes::BasicType::Sampler2D:
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Sampler2D:
case ShaderAst::BasicType::Void:
break;
}
break;
}
case ShaderNodes::BinaryType::Multiply:
case ShaderAst::BinaryType::Multiply:
{
switch (leftType)
{
case ShaderNodes::BasicType::Float1:
case ShaderAst::BasicType::Float1:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
case ShaderAst::BasicType::Float1:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
swapOperands = true;
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Mat4x4:
swapOperands = true;
return SpirvOp::OpMatrixTimesScalar;
@@ -374,21 +375,21 @@ namespace Nz
break;
}
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1:
case ShaderAst::BasicType::Float1:
return SpirvOp::OpVectorTimesScalar;
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
return SpirvOp::OpFMul;
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Mat4x4:
return SpirvOp::OpVectorTimesMatrix;
default:
@@ -398,23 +399,23 @@ namespace Nz
break;
}
case ShaderNodes::BasicType::Int1:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt1:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Int1:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt1:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
return SpirvOp::OpIMul;
case ShaderNodes::BasicType::Mat4x4:
case ShaderAst::BasicType::Mat4x4:
{
switch (rightType)
{
case ShaderNodes::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
case ShaderNodes::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
case ShaderNodes::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
case ShaderAst::BasicType::Float1: return SpirvOp::OpMatrixTimesScalar;
case ShaderAst::BasicType::Float4: return SpirvOp::OpMatrixTimesVector;
case ShaderAst::BasicType::Mat4x4: return SpirvOp::OpMatrixTimesMatrix;
default:
break;
@@ -442,7 +443,7 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::Branch& node)
void SpirvAstVisitor::Visit(ShaderAst::BranchStatement& node)
{
assert(!node.condStatements.empty());
auto& firstCond = node.condStatements.front();
@@ -450,7 +451,8 @@ namespace Nz
UInt32 previousConditionId = EvaluateExpression(firstCond.condition);
SpirvBlock previousContentBlock(m_writer);
m_currentBlock = &previousContentBlock;
Visit(firstCond.statement);
firstCond.statement->Visit(*this);
SpirvBlock mergeBlock(m_writer);
m_blocks.back().Append(SpirvOp::OpSelectionMerge, mergeBlock.GetLabelId(), SpirvSelectionControl::None);
@@ -458,7 +460,7 @@ namespace Nz
std::optional<std::size_t> nextBlock;
for (std::size_t statementIndex = 1; statementIndex < node.condStatements.size(); ++statementIndex)
{
const auto& statement = node.condStatements[statementIndex];
auto& statement = node.condStatements[statementIndex];
SpirvBlock contentBlock(m_writer);
@@ -469,7 +471,8 @@ namespace Nz
previousContentBlock = std::move(contentBlock);
m_currentBlock = &previousContentBlock;
Visit(statement.statement);
statement.statement->Visit(*this);
}
if (node.elseStatement)
@@ -477,7 +480,7 @@ namespace Nz
SpirvBlock elseBlock(m_writer);
m_currentBlock = &elseBlock;
Visit(node.elseStatement);
node.elseStatement->Visit(*this);
elseBlock.Append(SpirvOp::OpBranch, mergeBlock.GetLabelId()); //< FIXME: Shouldn't terminate twice
@@ -496,16 +499,16 @@ namespace Nz
m_currentBlock = &m_blocks.back();
}
void SpirvAstVisitor::Visit(ShaderNodes::Cast& node)
void SpirvAstVisitor::Visit(ShaderAst::CastExpression& node)
{
const ShaderExpressionType& targetExprType = node.exprType;
const ShaderAst::ShaderExpressionType& targetExprType = node.targetType;
assert(IsBasicType(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
StackVector<UInt32> exprResults = NazaraStackVector(UInt32, node.expressions.size());
for (const auto& exprPtr : node.expressions)
for (auto& exprPtr : node.expressions)
{
if (!exprPtr)
break;
@@ -527,21 +530,21 @@ namespace Nz
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalExpression& node)
void SpirvAstVisitor::Visit(ShaderAst::ConditionalExpression& node)
{
if (m_writer.IsConditionEnabled(node.conditionName))
Visit(node.truePath);
node.truePath->Visit(*this);
else
Visit(node.falsePath);
node.falsePath->Visit(*this);
}
void SpirvAstVisitor::Visit(ShaderNodes::ConditionalStatement& node)
void SpirvAstVisitor::Visit(ShaderAst::ConditionalStatement& node)
{
if (m_writer.IsConditionEnabled(node.conditionName))
Visit(node.statement);
node.statement->Visit(*this);
}
void SpirvAstVisitor::Visit(ShaderNodes::Constant& node)
void SpirvAstVisitor::Visit(ShaderAst::ConstantExpression& node)
{
std::visit([&] (const auto& value)
{
@@ -549,46 +552,42 @@ namespace Nz
}, node.value);
}
void SpirvAstVisitor::Visit(ShaderNodes::DeclareVariable& node)
void SpirvAstVisitor::Visit(ShaderAst::DeclareVariableStatement& node)
{
if (node.expression)
{
assert(node.variable->GetType() == ShaderNodes::VariableType::LocalVariable);
const auto& localVar = static_cast<const ShaderNodes::LocalVariable&>(*node.variable);
m_writer.WriteLocalVariable(localVar.name, EvaluateExpression(node.expression));
}
if (node.initialExpression)
m_writer.WriteLocalVariable(node.varName, EvaluateExpression(node.initialExpression));
}
void SpirvAstVisitor::Visit(ShaderNodes::Discard& /*node*/)
void SpirvAstVisitor::Visit(ShaderAst::DiscardStatement& /*node*/)
{
m_currentBlock->Append(SpirvOp::OpKill);
}
void SpirvAstVisitor::Visit(ShaderNodes::ExpressionStatement& node)
void SpirvAstVisitor::Visit(ShaderAst::ExpressionStatement& node)
{
Visit(node.expression);
node.expression->Visit(*this);
PopResultId();
}
void SpirvAstVisitor::Visit(ShaderNodes::Identifier& node)
void SpirvAstVisitor::Visit(ShaderAst::IdentifierExpression& node)
{
SpirvExpressionLoad loadVisitor(m_writer, *m_currentBlock);
PushResultId(loadVisitor.Evaluate(node));
}
void SpirvAstVisitor::Visit(ShaderNodes::IntrinsicCall& node)
void SpirvAstVisitor::Visit(ShaderAst::IntrinsicExpression& node)
{
switch (node.intrinsic)
{
case ShaderNodes::IntrinsicType::DotProduct:
case ShaderAst::IntrinsicType::DotProduct:
{
const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType();
const ShaderAst::ShaderExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsBasicType(vecExprType));
ShaderNodes::BasicType vecType = std::get<ShaderNodes::BasicType>(vecExprType);
ShaderAst::BasicType vecType = std::get<ShaderAst::BasicType>(vecExprType);
UInt32 typeId = m_writer.GetTypeId(node.GetComponentType(vecType));
UInt32 typeId = m_writer.GetTypeId(ShaderAst::GetComponentType(vecType));
UInt32 vec1 = EvaluateExpression(node.parameters[0]);
UInt32 vec2 = EvaluateExpression(node.parameters[1]);
@@ -600,18 +599,18 @@ namespace Nz
break;
}
case ShaderNodes::IntrinsicType::CrossProduct:
case ShaderAst::IntrinsicType::CrossProduct:
default:
throw std::runtime_error("not yet implemented");
}
}
void SpirvAstVisitor::Visit(ShaderNodes::NoOp& /*node*/)
void SpirvAstVisitor::Visit(ShaderAst::NoOpStatement& /*node*/)
{
// nothing to do
}
void SpirvAstVisitor::Visit(ShaderNodes::ReturnStatement& node)
void SpirvAstVisitor::Visit(ShaderAst::ReturnStatement& node)
{
if (node.returnExpr)
m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr));
@@ -619,30 +618,18 @@ namespace Nz
m_currentBlock->Append(SpirvOp::OpReturn);
}
void SpirvAstVisitor::Visit(ShaderNodes::Sample2D& node)
{
UInt32 typeId = m_writer.GetTypeId(ShaderNodes::BasicType::Float4);
UInt32 samplerId = EvaluateExpression(node.sampler);
UInt32 coordinatesId = EvaluateExpression(node.coordinates);
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpImageSampleImplicitLod, typeId, resultId, samplerId, coordinatesId);
PushResultId(resultId);
}
void SpirvAstVisitor::Visit(ShaderNodes::StatementBlock& node)
void SpirvAstVisitor::Visit(ShaderAst::MultiStatement& node)
{
for (auto& statement : node.statements)
Visit(statement);
statement->Visit(*this);
}
void SpirvAstVisitor::Visit(ShaderNodes::SwizzleOp& node)
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{
const ShaderExpressionType& targetExprType = node.GetExpressionType();
const ShaderAst::ShaderExpressionType& targetExprType = ShaderAst::GetExpressionType(node);
assert(IsBasicType(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
ShaderAst::BasicType targetType = std::get<ShaderAst::BasicType>(targetExprType);
UInt32 exprResultId = EvaluateExpression(node.expression);
UInt32 resultId = m_writer.AllocateResultId();
@@ -666,7 +653,7 @@ namespace Nz
// Extract a single component from the vector
assert(node.componentCount == 1);
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderNodes::SwizzleComponent::First) );
m_currentBlock->Append(SpirvOp::OpCompositeExtract, m_writer.GetTypeId(targetType), resultId, exprResultId, UInt32(node.components[0]) - UInt32(ShaderAst::SwizzleComponent::First) );
}
PushResultId(resultId);

View File

@@ -3,7 +3,6 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvConstantCache.hpp>
#include <Nazara/Shader/ShaderAst.hpp>
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Utility/FieldOffsets.hpp>
#include <tsl/ordered_map.h>
@@ -536,7 +535,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector2f> || std::is_same_v<T, Vector2i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector2f>) ? ShaderNodes::BasicType::Float2 : ShaderNodes::BasicType::Int2),
BuildType((std::is_same_v<T, Vector2f>) ? ShaderAst::BasicType::Float2 : ShaderAst::BasicType::Int2),
{
BuildConstant(arg.x),
BuildConstant(arg.y)
@@ -546,7 +545,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector3f> || std::is_same_v<T, Vector3i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector3f>) ? ShaderNodes::BasicType::Float3 : ShaderNodes::BasicType::Int3),
BuildType((std::is_same_v<T, Vector3f>) ? ShaderAst::BasicType::Float3 : ShaderAst::BasicType::Int3),
{
BuildConstant(arg.x),
BuildConstant(arg.y),
@@ -557,7 +556,7 @@ namespace Nz
else if constexpr (std::is_same_v<T, Vector4f> || std::is_same_v<T, Vector4i>)
{
return ConstantComposite{
BuildType((std::is_same_v<T, Vector4f>) ? ShaderNodes::BasicType::Float4 : ShaderNodes::BasicType::Int4),
BuildType((std::is_same_v<T, Vector4f>) ? ShaderAst::BasicType::Float4 : ShaderAst::BasicType::Int4),
{
BuildConstant(arg.x),
BuildConstant(arg.y),
@@ -571,7 +570,7 @@ namespace Nz
}, value));
}
auto SpirvConstantCache::BuildPointerType(const ShaderNodes::BasicType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::BasicType& type, SpirvStorageClass storageClass) -> TypePtr
{
return std::make_shared<Type>(SpirvConstantCache::Pointer{
SpirvConstantCache::BuildType(type),
@@ -579,55 +578,55 @@ namespace Nz
});
}
auto SpirvConstantCache::BuildPointerType(const ShaderAst& shader, const ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
auto SpirvConstantCache::BuildPointerType(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) -> TypePtr
{
return std::make_shared<Type>(SpirvConstantCache::Pointer{
SpirvConstantCache::BuildType(shader, type),
SpirvConstantCache::BuildType(type),
storageClass
});
}
auto SpirvConstantCache::BuildType(const ShaderNodes::BasicType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::BasicType& type) -> TypePtr
{
return std::make_shared<Type>([&]() -> AnyType
{
switch (type)
{
case ShaderNodes::BasicType::Boolean:
case ShaderAst::BasicType::Boolean:
return Bool{};
case ShaderNodes::BasicType::Float1:
case ShaderAst::BasicType::Float1:
return Float{ 32 };
case ShaderNodes::BasicType::Int1:
case ShaderAst::BasicType::Int1:
return Integer{ 32, true };
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
case ShaderNodes::BasicType::Float4:
case ShaderNodes::BasicType::Int2:
case ShaderNodes::BasicType::Int3:
case ShaderNodes::BasicType::Int4:
case ShaderNodes::BasicType::UInt2:
case ShaderNodes::BasicType::UInt3:
case ShaderNodes::BasicType::UInt4:
case ShaderAst::BasicType::Float2:
case ShaderAst::BasicType::Float3:
case ShaderAst::BasicType::Float4:
case ShaderAst::BasicType::Int2:
case ShaderAst::BasicType::Int3:
case ShaderAst::BasicType::Int4:
case ShaderAst::BasicType::UInt2:
case ShaderAst::BasicType::UInt3:
case ShaderAst::BasicType::UInt4:
{
auto vecType = BuildType(ShaderNodes::Node::GetComponentType(type));
UInt32 componentCount = ShaderNodes::Node::GetComponentCount(type);
auto vecType = BuildType(ShaderAst::GetComponentType(type));
UInt32 componentCount = ShaderAst::GetComponentCount(type);
return Vector{ vecType, componentCount };
}
case ShaderNodes::BasicType::Mat4x4:
return Matrix{ BuildType(ShaderNodes::BasicType::Float4), 4u };
case ShaderAst::BasicType::Mat4x4:
return Matrix{ BuildType(ShaderAst::BasicType::Float4), 4u };
case ShaderNodes::BasicType::UInt1:
case ShaderAst::BasicType::UInt1:
return Integer{ 32, false };
case ShaderNodes::BasicType::Void:
case ShaderAst::BasicType::Void:
return Void{};
case ShaderNodes::BasicType::Sampler2D:
case ShaderAst::BasicType::Sampler2D:
{
auto imageType = Image{
{}, //< qualifier
@@ -635,7 +634,7 @@ namespace Nz
{}, //< sampled
SpirvDim::Dim2D, //< dim
SpirvImageFormat::Unknown, //< format
BuildType(ShaderNodes::BasicType::Float1), //< sampledType
BuildType(ShaderAst::BasicType::Float1), //< sampledType
false, //< arrayed,
false //< multisampled
};
@@ -648,16 +647,16 @@ namespace Nz
}());
}
auto SpirvConstantCache::BuildType(const ShaderAst& shader, const ShaderExpressionType& type) -> TypePtr
auto SpirvConstantCache::BuildType(const ShaderAst::ShaderExpressionType& type) -> TypePtr
{
return std::visit([&](auto&& arg) -> TypePtr
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
if constexpr (std::is_same_v<T, ShaderAst::BasicType>)
return BuildType(arg);
else if constexpr (std::is_same_v<T, std::string>)
{
// Register struct members type
/*// Register struct members type
const auto& structs = shader.GetStructs();
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == arg; });
if (it == structs.end())
@@ -675,7 +674,8 @@ namespace Nz
sMembers.type = BuildType(shader, member.type);
}
return std::make_shared<Type>(std::move(sType));
return std::make_shared<Type>(std::move(sType));*/
return nullptr;
}
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");

View File

@@ -16,7 +16,7 @@ namespace Nz
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
}
UInt32 SpirvExpressionLoad::Evaluate(ShaderNodes::Expression& node)
UInt32 SpirvExpressionLoad::Evaluate(ShaderAst::Expression& node)
{
node.Visit(*this);
@@ -41,7 +41,7 @@ namespace Nz
}, m_value);
}
void SpirvExpressionLoad::Visit(ShaderNodes::AccessMember& node)
/*void SpirvExpressionLoad::Visit(ShaderAst::AccessMemberExpression& node)
{
Visit(node.structExpr);
@@ -49,6 +49,8 @@ namespace Nz
{
[&](const Pointer& pointer)
{
ShaderAst::ShaderExpressionType exprType = GetExpressionType(node.structExpr);
UInt32 resultId = m_writer.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(node.exprType, pointer.storage); //< FIXME
UInt32 typeId = m_writer.GetTypeId(node.exprType);
@@ -87,40 +89,15 @@ namespace Nz
throw std::runtime_error("an internal error occurred");
}
}, m_value);
}
}*/
void SpirvExpressionLoad::Visit(ShaderNodes::Identifier& node)
void SpirvExpressionLoad::Visit(ShaderAst::IdentifierExpression& node)
{
Visit(node.var);
}
void SpirvExpressionLoad::Visit(ShaderNodes::InputVariable& var)
{
auto inputVar = m_writer.GetInputVariable(var.name);
if (auto resultIdOpt = m_writer.ReadVariable(inputVar, SpirvWriter::OnlyCache{}))
m_value = Value{ *resultIdOpt };
if (node.identifier == "d")
m_value = Value{ m_writer.ReadLocalVariable(node.identifier) };
else
m_value = Pointer{ SpirvStorageClass::Input, inputVar.varId, inputVar.typeId };
}
m_value = Value{ m_writer.ReadParameterVariable(node.identifier) };
void SpirvExpressionLoad::Visit(ShaderNodes::LocalVariable& var)
{
m_value = Value{ m_writer.ReadLocalVariable(var.name) };
}
void SpirvExpressionLoad::Visit(ShaderNodes::ParameterVariable& var)
{
m_value = Value{ m_writer.ReadParameterVariable(var.name) };
}
void SpirvExpressionLoad::Visit(ShaderNodes::UniformVariable& var)
{
auto uniformVar = m_writer.GetUniformVariable(var.name);
if (auto resultIdOpt = m_writer.ReadVariable(uniformVar, SpirvWriter::OnlyCache{}))
m_value = Value{ *resultIdOpt };
else
m_value = Pointer{ SpirvStorageClass::Uniform, uniformVar.varId, uniformVar.typeId };
//Visit(node.var);
}
}

View File

@@ -15,9 +15,9 @@ namespace Nz
template<class... Ts> overloaded(Ts...)->overloaded<Ts...>;
}
void SpirvExpressionStore::Store(const ShaderNodes::ExpressionPtr& node, UInt32 resultId)
void SpirvExpressionStore::Store(ShaderAst::ExpressionPtr& node, UInt32 resultId)
{
Visit(node);
node->Visit(*this);
std::visit(overloaded
{
@@ -36,7 +36,7 @@ namespace Nz
}, m_value);
}
void SpirvExpressionStore::Visit(ShaderNodes::AccessMember& node)
/*void SpirvExpressionStore::Visit(ShaderAst::AccessMemberExpression& node)
{
Visit(node.structExpr);
@@ -70,34 +70,15 @@ namespace Nz
throw std::runtime_error("an internal error occurred");
}
}, m_value);
}
}*/
void SpirvExpressionStore::Visit(ShaderNodes::Identifier& node)
void SpirvExpressionStore::Visit(ShaderAst::IdentifierExpression& node)
{
Visit(node.var);
m_value = LocalVar{ node.identifier };
}
void SpirvExpressionStore::Visit(ShaderNodes::SwizzleOp& node)
void SpirvExpressionStore::Visit(ShaderAst::SwizzleExpression& node)
{
throw std::runtime_error("not yet implemented");
}
void SpirvExpressionStore::Visit(ShaderNodes::BuiltinVariable& var)
{
const auto& outputVar = m_writer.GetBuiltinVariable(var.entry);
m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId };
}
void SpirvExpressionStore::Visit(ShaderNodes::LocalVariable& var)
{
m_value = LocalVar{ var.name };
}
void SpirvExpressionStore::Visit(ShaderNodes::OutputVariable& var)
{
const auto& outputVar = m_writer.GetOutputVariable(var.name);
m_value = Pointer{ SpirvStorageClass::Output, outputVar.varId };
}
}

View File

@@ -26,155 +26,131 @@ namespace Nz
{
namespace
{
class PreVisitor : public ShaderAstRecursiveVisitor, public ShaderVarVisitor
class PreVisitor : public ShaderAst::AstRecursiveVisitor
{
public:
using BuiltinContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::BuiltinVariable>>;
using ExtInstList = std::unordered_set<std::string>;
using LocalContainer = std::unordered_set<std::shared_ptr<const ShaderNodes::LocalVariable>>;
using ParameterContainer = std::unordered_set< std::shared_ptr<const ShaderNodes::ParameterVariable>>;
using LocalContainer = std::unordered_set<ShaderAst::ShaderExpressionType>;
PreVisitor(const ShaderAst& shader, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_shader(shader),
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_cache(cache),
m_conditions(conditions),
m_constantCache(constantCache)
{
}
using ShaderAstRecursiveVisitor::Visit;
using ShaderVarVisitor::Visit;
void Visit(ShaderNodes::AccessMember& node) override
void Visit(ShaderAst::AccessMemberExpression& node) override
{
for (std::size_t index : node.memberIndices)
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));
/*for (std::size_t index : node.memberIdentifiers)
m_constantCache.Register(*SpirvConstantCache::BuildConstant(Int32(index)));*/
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderNodes::ConditionalExpression& node) override
void Visit(ShaderAst::ConditionalExpression& node) override
{
std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
Visit(node.truePath);
else
Visit(node.falsePath);
Visit(node.falsePath);*/
}
void Visit(ShaderNodes::ConditionalStatement& node) override
void Visit(ShaderAst::ConditionalStatement& node) override
{
std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
/*std::size_t conditionIndex = m_shader.FindConditionByName(node.conditionName);
assert(conditionIndex != ShaderAst::InvalidCondition);
if (TestBit<Nz::UInt64>(m_conditions.enabledConditions, conditionIndex))
Visit(node.statement);
Visit(node.statement);*/
}
void Visit(ShaderNodes::Constant& node) override
void Visit(ShaderAst::ConstantExpression& node) override
{
std::visit([&](auto&& arg)
{
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
}, node.value);
ShaderAstRecursiveVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderNodes::DeclareVariable& node) override
void Visit(ShaderAst::DeclareFunctionStatement& node) override
{
Visit(node.variable);
ShaderAstRecursiveVisitor::Visit(node);
m_constantCache.Register(*SpirvConstantCache::BuildType(node.returnType));
for (auto& parameter : node.parameters)
m_constantCache.Register(*SpirvConstantCache::BuildType(parameter.type));
}
void Visit(ShaderNodes::Identifier& node) override
void Visit(ShaderAst::DeclareStructStatement& node) override
{
Visit(node.var);
ShaderAstRecursiveVisitor::Visit(node);
for (auto& field : node.description.members)
m_constantCache.Register(*SpirvConstantCache::BuildType(field.type));
}
void Visit(ShaderNodes::IntrinsicCall& node) override
void Visit(ShaderAst::DeclareVariableStatement& node) override
{
ShaderAstRecursiveVisitor::Visit(node);
variableTypes.insert(node.varType);
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderAst::IdentifierExpression& node) override
{
variableTypes.insert(GetExpressionType(node, m_cache));
AstRecursiveVisitor::Visit(node);
}
void Visit(ShaderAst::IntrinsicExpression& node) override
{
AstRecursiveVisitor::Visit(node);
switch (node.intrinsic)
{
// Require GLSL.std.450
case ShaderNodes::IntrinsicType::CrossProduct:
case ShaderAst::IntrinsicType::CrossProduct:
extInsts.emplace("GLSL.std.450");
break;
// Part of SPIR-V core
case ShaderNodes::IntrinsicType::DotProduct:
case ShaderAst::IntrinsicType::DotProduct:
break;
}
}
void Visit(ShaderNodes::BuiltinVariable& var) override
{
builtinVars.insert(std::static_pointer_cast<const ShaderNodes::BuiltinVariable>(var.shared_from_this()));
}
void Visit(ShaderNodes::InputVariable& /*var*/) override
{
/* Handled by ShaderAst */
}
void Visit(ShaderNodes::LocalVariable& var) override
{
localVars.insert(std::static_pointer_cast<const ShaderNodes::LocalVariable>(var.shared_from_this()));
}
void Visit(ShaderNodes::OutputVariable& /*var*/) override
{
/* Handled by ShaderAst */
}
void Visit(ShaderNodes::ParameterVariable& var) override
{
paramVars.insert(std::static_pointer_cast<const ShaderNodes::ParameterVariable>(var.shared_from_this()));
}
void Visit(ShaderNodes::UniformVariable& /*var*/) override
{
/* Handled by ShaderAst */
}
BuiltinContainer builtinVars;
ExtInstList extInsts;
LocalContainer localVars;
ParameterContainer paramVars;
LocalContainer variableTypes;
private:
const ShaderAst& m_shader;
ShaderAst::AstCache* m_cache;
const SpirvWriter::States& m_conditions;
SpirvConstantCache& m_constantCache;
};
template<typename T>
constexpr ShaderNodes::BasicType GetBasicType()
constexpr ShaderAst::BasicType GetBasicType()
{
if constexpr (std::is_same_v<T, bool>)
return ShaderNodes::BasicType::Boolean;
return ShaderAst::BasicType::Boolean;
else if constexpr (std::is_same_v<T, float>)
return(ShaderNodes::BasicType::Float1);
return(ShaderAst::BasicType::Float1);
else if constexpr (std::is_same_v<T, Int32>)
return(ShaderNodes::BasicType::Int1);
return(ShaderAst::BasicType::Int1);
else if constexpr (std::is_same_v<T, Vector2f>)
return(ShaderNodes::BasicType::Float2);
return(ShaderAst::BasicType::Float2);
else if constexpr (std::is_same_v<T, Vector3f>)
return(ShaderNodes::BasicType::Float3);
return(ShaderAst::BasicType::Float3);
else if constexpr (std::is_same_v<T, Vector4f>)
return(ShaderNodes::BasicType::Float4);
return(ShaderAst::BasicType::Float4);
else if constexpr (std::is_same_v<T, Vector2i32>)
return(ShaderNodes::BasicType::Int2);
return(ShaderAst::BasicType::Int2);
else if constexpr (std::is_same_v<T, Vector3i32>)
return(ShaderNodes::BasicType::Int3);
return(ShaderAst::BasicType::Int3);
else if constexpr (std::is_same_v<T, Vector4i32>)
return(ShaderNodes::BasicType::Int4);
return(ShaderAst::BasicType::Int4);
else
static_assert(AlwaysFalse<T>::value, "unhandled type");
}
@@ -198,7 +174,7 @@ namespace Nz
tsl::ordered_map<std::string, ExtVar> parameterIds;
tsl::ordered_map<std::string, ExtVar> uniformIds;
std::unordered_map<std::string, UInt32> extensionInstructions;
std::unordered_map<ShaderNodes::BuiltinEntry, ExtVar> builtinIds;
std::unordered_map<ShaderAst::BuiltinEntry, ExtVar> builtinIds;
std::unordered_map<std::string, UInt32> varToResult;
std::vector<Func> funcs;
std::vector<SpirvBlock> functionBlocks;
@@ -219,13 +195,12 @@ namespace Nz
{
}
std::vector<UInt32> SpirvWriter::Generate(const ShaderAst& shader, const States& conditions)
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
{
std::string error;
if (!ValidateShader(shader, &error))
if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache))
throw std::runtime_error("Invalid shader AST: " + error);
m_context.shader = &shader;
m_context.states = &conditions;
State state;
@@ -235,23 +210,19 @@ namespace Nz
m_currentState = nullptr;
});
std::vector<ShaderNodes::StatementPtr> functionStatements;
std::vector<ShaderAst::StatementPtr> functionStatements;
ShaderAstCloner cloner;
PreVisitor preVisitor(shader, conditions, state.constantTypeCache);
for (const auto& func : shader.GetFunctions())
{
functionStatements.emplace_back(cloner.Clone(func.statement));
preVisitor.Visit(func.statement);
}
ShaderAst::AstCloner cloner;
// Register all extended instruction sets
PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache);
shader->Visit(preVisitor);
for (const std::string& extInst : preVisitor.extInsts)
state.extensionInstructions[extInst] = AllocateResultId();
// Register all types
for (const auto& func : shader.GetFunctions())
/*for (const auto& func : shader.GetFunctions())
{
RegisterType(func.returnType);
for (const auto& param : func.parameters)
@@ -270,8 +241,8 @@ namespace Nz
for (const auto& func : shader.GetFunctions())
RegisterFunctionType(func.returnType, func.parameters);
for (const auto& local : preVisitor.localVars)
RegisterType(local->type);
for (const auto& type : preVisitor.variableTypes)
RegisterType(type);
for (const auto& builtin : preVisitor.builtinVars)
RegisterType(builtin->type);
@@ -283,7 +254,7 @@ namespace Nz
SpirvBuiltIn builtinDecoration;
switch (builtin->entry)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
case ShaderAst::BuiltinEntry::VertexPosition:
variable.debugName = "builtin_VertexPosition";
variable.storageClass = SpirvStorageClass::Output;
@@ -294,10 +265,10 @@ namespace Nz
throw std::runtime_error("unexpected builtin type");
}
const ShaderExpressionType& builtinExprType = builtin->type;
const ShaderAst::ShaderExpressionType& builtinExprType = builtin->type;
assert(IsBasicType(builtinExprType));
ShaderNodes::BasicType builtinType = std::get<ShaderNodes::BasicType>(builtinExprType);
ShaderAst::BasicType builtinType = std::get<ShaderAst::BasicType>(builtinExprType);
variable.type = SpirvConstantCache::BuildPointerType(builtinType, variable.storageClass);
@@ -420,7 +391,7 @@ namespace Nz
if (!state.functionBlocks.back().IsTerminated())
{
assert(func.returnType == ShaderExpressionType(ShaderNodes::BasicType::Void));
assert(func.returnType == ShaderAst::ShaderExpressionType(ShaderAst::BasicType::Void));
state.functionBlocks.back().Append(SpirvOp::OpReturn);
}
@@ -475,14 +446,14 @@ namespace Nz
if (m_context.shader->GetStage() == ShaderStageType::Fragment)
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft);
}
}*/
std::vector<UInt32> ret;
MergeSections(ret, state.header);
/*MergeSections(ret, state.header);
MergeSections(ret, state.debugInfo);
MergeSections(ret, state.annotations);
MergeSections(ret, state.constants);
MergeSections(ret, state.instructions);
MergeSections(ret, state.instructions);*/
return ret;
}
@@ -516,16 +487,16 @@ namespace Nz
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
}
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
for (const auto& parameter : parameters)
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(*m_context.shader, parameter.type, SpirvStorageClass::Function));
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(parameter.type, SpirvStorageClass::Function));
return SpirvConstantCache::Function{
SpirvConstantCache::BuildType(*m_context.shader, retType),
SpirvConstantCache::BuildType(retType),
std::move(parameterTypes)
};
}
@@ -535,12 +506,12 @@ namespace Nz
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
}
UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
UInt32 SpirvWriter::GetFunctionTypeId(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
{
return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) });
}
auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar&
auto SpirvWriter::GetBuiltinVariable(ShaderAst::BuiltinEntry builtin) const -> const ExtVar&
{
auto it = m_currentState->builtinIds.find(builtin);
assert(it != m_currentState->builtinIds.end());
@@ -572,14 +543,14 @@ namespace Nz
return it.value();
}
UInt32 SpirvWriter::GetPointerTypeId(const ShaderExpressionType& type, SpirvStorageClass storageClass) const
UInt32 SpirvWriter::GetPointerTypeId(const ShaderAst::ShaderExpressionType& type, SpirvStorageClass storageClass) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::GetTypeId(const ShaderExpressionType& type) const
UInt32 SpirvWriter::GetTypeId(const ShaderAst::ShaderExpressionType& type) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(*m_context.shader, type));
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildType(type));
}
UInt32 SpirvWriter::ReadInputVariable(const std::string& name)
@@ -673,20 +644,20 @@ namespace Nz
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildConstant(value));
}
UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
UInt32 SpirvWriter::RegisterFunctionType(ShaderAst::ShaderExpressionType retType, const std::vector<FunctionParameter>& parameters)
{
return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass)
UInt32 SpirvWriter::RegisterPointerType(ShaderAst::ShaderExpressionType type, SpirvStorageClass storageClass)
{
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(*m_context.shader, type, storageClass));
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildPointerType(type, storageClass));
}
UInt32 SpirvWriter::RegisterType(ShaderExpressionType type)
UInt32 SpirvWriter::RegisterType(ShaderAst::ShaderExpressionType type)
{
assert(m_currentState);
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(*m_context.shader, type));
return m_currentState->constantTypeCache.Register(*SpirvConstantCache::BuildType(type));
}
void SpirvWriter::WriteLocalVariable(std::string name, UInt32 resultId)