Shader: Add support for partial sanitization
This commit is contained in:
@@ -249,7 +249,7 @@ struct VertIn
|
||||
}
|
||||
|
||||
[entry(vert), cond(Billboard)]
|
||||
fn billboardMain(input: VertIn) -> VertOut
|
||||
fn billboardMain(input: VertIn) -> VertToFrag
|
||||
{
|
||||
let size = input.billboardSizeRot.xy;
|
||||
let sinCos = input.billboardSizeRot.zw;
|
||||
|
||||
@@ -25,7 +25,7 @@ namespace Nz
|
||||
m_shaderModule = moduleResolver.Resolve(moduleName);
|
||||
NazaraAssert(m_shaderModule, "invalid shader module");
|
||||
|
||||
Validate(*m_shaderModule);
|
||||
m_shaderModule = Validate(*m_shaderModule, &m_optionIndexByName);
|
||||
|
||||
m_onShaderModuleUpdated.Connect(moduleResolver.OnModuleUpdated, [this, name = std::move(moduleName)](ShaderModuleResolver* resolver, const std::string& updatedModuleName)
|
||||
{
|
||||
@@ -41,8 +41,7 @@ namespace Nz
|
||||
|
||||
try
|
||||
{
|
||||
// FIXME: Validate is destructive, in case of failure it can invalidate the shader
|
||||
Validate(*newShaderModule);
|
||||
m_shaderModule = Validate(*newShaderModule, &m_optionIndexByName);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
@@ -50,8 +49,6 @@ namespace Nz
|
||||
return;
|
||||
}
|
||||
|
||||
m_shaderModule = std::move(newShaderModule);
|
||||
|
||||
// Clear cache
|
||||
m_combinations.clear();
|
||||
|
||||
@@ -65,7 +62,7 @@ namespace Nz
|
||||
{
|
||||
NazaraAssert(m_shaderModule, "invalid shader module");
|
||||
|
||||
Validate(*m_shaderModule);
|
||||
Validate(*m_shaderModule, &m_optionIndexByName);
|
||||
}
|
||||
|
||||
const std::shared_ptr<ShaderModule>& UberShader::Get(const Config& config)
|
||||
@@ -85,13 +82,17 @@ namespace Nz
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void UberShader::Validate(ShaderAst::Module& module)
|
||||
ShaderAst::ModulePtr UberShader::Validate(const ShaderAst::Module& module, std::unordered_map<std::string, Option>* options)
|
||||
{
|
||||
NazaraAssert(m_shaderStages != 0, "there must be at least one shader stage");
|
||||
assert(options);
|
||||
|
||||
//TODO: Try to partially sanitize shader?
|
||||
// Try to partially sanitize shader
|
||||
|
||||
std::size_t optionCount = 0;
|
||||
ShaderAst::SanitizeVisitor::Options sanitizeOptions;
|
||||
sanitizeOptions.allowPartialSanitization = true;
|
||||
|
||||
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module, sanitizeOptions);
|
||||
|
||||
ShaderStageTypeFlags supportedStageType;
|
||||
|
||||
@@ -101,21 +102,24 @@ namespace Nz
|
||||
supportedStageType |= stageType;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, Option> optionByName;
|
||||
callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option)
|
||||
{
|
||||
//TODO: Check optionType
|
||||
|
||||
m_optionIndexByName[option.optName] = Option{
|
||||
optionByName[option.optName] = Option{
|
||||
CRC32(option.optName)
|
||||
};
|
||||
|
||||
optionCount++;
|
||||
};
|
||||
|
||||
ShaderAst::AstReflect reflect;
|
||||
reflect.Reflect(*module.rootNode, callbacks);
|
||||
reflect.Reflect(*sanitizedModule->rootNode, callbacks);
|
||||
|
||||
if ((m_shaderStages & supportedStageType) != m_shaderStages)
|
||||
throw std::runtime_error("shader doesn't support all required shader stages");
|
||||
|
||||
*options = std::move(optionByName);
|
||||
|
||||
return sanitizedModule;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -481,6 +481,16 @@ namespace Nz::ShaderAst
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(TypeExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<TypeExpression>();
|
||||
clone->typeId = node.typeId;
|
||||
|
||||
clone->cachedExpressionType = node.cachedExpressionType;
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
ExpressionPtr AstCloner::Clone(VariableValueExpression& node)
|
||||
{
|
||||
auto clone = std::make_unique<VariableValueExpression>();
|
||||
|
||||
@@ -862,7 +862,7 @@ namespace Nz::ShaderAst
|
||||
|
||||
const auto& constantExpr = static_cast<ConstantValueExpression&>(*expressions[i]);
|
||||
|
||||
if (!constantValues.empty() && GetExpressionType(constantValues.front()) != GetExpressionType(constantExpr.value))
|
||||
if (!constantValues.empty() && GetConstantType(constantValues.front()) != GetConstantType(constantExpr.value))
|
||||
{
|
||||
// Unhandled case, all cast parameters are expected to be of the same type
|
||||
constantValues.clear();
|
||||
@@ -940,16 +940,24 @@ namespace Nz::ShaderAst
|
||||
std::vector<BranchStatement::ConditionalStatement> statements;
|
||||
StatementPtr elseStatement;
|
||||
|
||||
bool continuePropagation = true;
|
||||
for (auto& condStatement : node.condStatements)
|
||||
{
|
||||
auto cond = CloneExpression(condStatement.condition);
|
||||
|
||||
if (cond->GetType() == NodeType::ConstantValueExpression)
|
||||
if (continuePropagation && cond->GetType() == NodeType::ConstantValueExpression)
|
||||
{
|
||||
auto& constant = static_cast<ConstantValueExpression&>(*cond);
|
||||
|
||||
const ExpressionType& constantType = GetExpressionType(constant);
|
||||
if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean)
|
||||
const ExpressionType* constantType = GetExpressionType(constant);
|
||||
if (!constantType)
|
||||
{
|
||||
// unresolved type, can't continue propagating this branch
|
||||
continuePropagation = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!IsPrimitiveType(*constantType) || std::get<PrimitiveType>(*constantType) != PrimitiveType::Boolean)
|
||||
continue;
|
||||
|
||||
bool cValue = std::get<bool>(constant.value);
|
||||
@@ -1017,8 +1025,12 @@ namespace Nz::ShaderAst
|
||||
if (!m_options.constantQueryCallback)
|
||||
return AstCloner::Clone(node);
|
||||
|
||||
auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
|
||||
constant->cachedExpressionType = GetExpressionType(constant->value);
|
||||
const ConstantValue* constantValue = m_options.constantQueryCallback(node.constantId);
|
||||
if (!constantValue)
|
||||
return AstCloner::Clone(node);
|
||||
|
||||
auto constant = ShaderBuilder::Constant(*constantValue);
|
||||
constant->cachedExpressionType = GetConstantType(constant->value);
|
||||
|
||||
return constant;
|
||||
}
|
||||
@@ -1155,7 +1167,7 @@ namespace Nz::ShaderAst
|
||||
}, lhs.value);
|
||||
|
||||
if (optimized)
|
||||
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
||||
optimized->cachedExpressionType = GetConstantType(optimized->value);
|
||||
|
||||
return optimized;
|
||||
}
|
||||
@@ -1221,7 +1233,7 @@ namespace Nz::ShaderAst
|
||||
}, operand.value);
|
||||
|
||||
if (optimized)
|
||||
optimized->cachedExpressionType = GetExpressionType(optimized->value);
|
||||
optimized->cachedExpressionType = GetConstantType(optimized->value);
|
||||
|
||||
return optimized;
|
||||
}
|
||||
|
||||
@@ -109,6 +109,11 @@ namespace Nz::ShaderAst
|
||||
node.expression->Visit(*this);
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(TypeExpression& node)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void AstRecursiveVisitor::Visit(VariableValueExpression& /*node*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
|
||||
@@ -174,6 +174,11 @@ namespace Nz::ShaderAst
|
||||
SizeT(node.structTypeId);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(TypeExpression& node)
|
||||
{
|
||||
SizeT(node.typeId);
|
||||
}
|
||||
|
||||
void AstSerializerBase::Serialize(FunctionExpression& node)
|
||||
{
|
||||
SizeT(node.funcId);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
||||
#include <cassert>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
@@ -104,7 +105,10 @@ namespace Nz::ShaderAst
|
||||
|
||||
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
|
||||
{
|
||||
if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1)
|
||||
const ExpressionType* exprType = GetExpressionType(node);
|
||||
assert(exprType);
|
||||
|
||||
if (IsPrimitiveType(*exprType) && node.componentCount > 1)
|
||||
// Swizzling more than a component on a primitive produces a rvalue (a.xxxx cannot be assigned)
|
||||
m_expressionCategory = ExpressionCategory::RValue;
|
||||
else
|
||||
@@ -133,6 +137,11 @@ namespace Nz::ShaderAst
|
||||
}
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(TypeExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ExpressionCategory::LValue;
|
||||
}
|
||||
|
||||
void ShaderAstValueCategory::Visit(VariableValueExpression& /*node*/)
|
||||
{
|
||||
m_expressionCategory = ExpressionCategory::LValue;
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
namespace Nz::ShaderAst
|
||||
{
|
||||
ExpressionType GetExpressionType(const ConstantValue& constant)
|
||||
ExpressionType GetConstantType(const ConstantValue& constant)
|
||||
{
|
||||
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
|
||||
{
|
||||
|
||||
@@ -15,10 +15,11 @@ namespace Nz::ShaderAst
|
||||
|
||||
void DependencyCheckerVisitor::Visit(CallFunctionExpression& node)
|
||||
{
|
||||
const auto& targetFuncType = GetExpressionType(*node.targetFunction);
|
||||
assert(std::holds_alternative<FunctionType>(targetFuncType));
|
||||
const ExpressionType* targetFuncType = GetExpressionType(*node.targetFunction);
|
||||
assert(targetFuncType);
|
||||
assert(std::holds_alternative<FunctionType>(*targetFuncType));
|
||||
|
||||
const auto& funcType = std::get<FunctionType>(targetFuncType);
|
||||
const auto& funcType = std::get<FunctionType>(*targetFuncType);
|
||||
|
||||
assert(m_currentFunctionIndex);
|
||||
if (m_currentVariableDeclIndex)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -36,7 +36,7 @@ namespace Nz
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
assert(currentFunction);
|
||||
currentFunction->calledFunctions.UnboundedSet(std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex);
|
||||
currentFunction->calledFunctions.UnboundedSet(std::get<ShaderAst::FunctionType>(*GetExpressionType(*node.targetFunction)).funcIndex);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::ConditionalExpression& /*node*/) override
|
||||
@@ -307,7 +307,7 @@ namespace Nz
|
||||
Append(type.GetResultingValue());
|
||||
}
|
||||
|
||||
void GlslWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||
void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
||||
{
|
||||
throw std::runtime_error("unexpected FunctionType");
|
||||
}
|
||||
@@ -829,8 +829,9 @@ namespace Nz
|
||||
{
|
||||
Visit(node.expr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
|
||||
assert(IsStructType(exprType));
|
||||
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr);
|
||||
assert(exprType);
|
||||
assert(IsStructType(*exprType));
|
||||
|
||||
for (const std::string& identifier : node.identifiers)
|
||||
Append(".", identifier);
|
||||
@@ -840,8 +841,9 @@ namespace Nz
|
||||
{
|
||||
Visit(node.expr, true);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
|
||||
assert(!IsStructType(exprType));
|
||||
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr);
|
||||
assert(exprType);
|
||||
assert(!IsStructType(*exprType));
|
||||
|
||||
// Array access
|
||||
assert(node.indices.size() == 1);
|
||||
@@ -1326,9 +1328,10 @@ namespace Nz
|
||||
{
|
||||
assert(node.returnExpr);
|
||||
|
||||
const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr);
|
||||
assert(IsStructType(returnType));
|
||||
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
|
||||
const ShaderAst::ExpressionType* returnType = GetExpressionType(*node.returnExpr);
|
||||
assert(returnType);
|
||||
assert(IsStructType(*returnType));
|
||||
std::size_t structIndex = std::get<ShaderAst::StructType>(*returnType).structIndex;
|
||||
const auto& structData = Retrieve(m_currentState->structs, structIndex);
|
||||
|
||||
std::string outputStructVarName;
|
||||
|
||||
@@ -182,7 +182,7 @@ namespace Nz
|
||||
type.GetExpression()->Visit(*this);
|
||||
}
|
||||
|
||||
void LangWriter::Append(const ShaderAst::FunctionType& functionType)
|
||||
void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
|
||||
{
|
||||
throw std::runtime_error("unexpected function type");
|
||||
}
|
||||
|
||||
@@ -263,7 +263,7 @@ namespace Nz::ShaderLang
|
||||
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
|
||||
|
||||
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
|
||||
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
|
||||
if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
|
||||
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
|
||||
|
||||
const std::string& versionStr = std::get<std::string>(constantValue.value);
|
||||
@@ -302,7 +302,7 @@ namespace Nz::ShaderLang
|
||||
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
|
||||
|
||||
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
|
||||
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
|
||||
if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
|
||||
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
|
||||
|
||||
const std::string& uuidStr = std::get<std::string>(constantValue.value);
|
||||
|
||||
@@ -67,9 +67,9 @@ namespace Nz
|
||||
throw std::runtime_error("unexpected type");
|
||||
};
|
||||
|
||||
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
|
||||
const ShaderAst::ExpressionType& leftType = GetExpressionType(*node.left);
|
||||
const ShaderAst::ExpressionType& rightType = GetExpressionType(*node.right);
|
||||
const ShaderAst::ExpressionType& resultType = *GetExpressionType(node);
|
||||
const ShaderAst::ExpressionType& leftType = *GetExpressionType(*node.left);
|
||||
const ShaderAst::ExpressionType& rightType = *GetExpressionType(*node.right);
|
||||
|
||||
ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType);
|
||||
//ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType);
|
||||
@@ -405,7 +405,7 @@ namespace Nz
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node)
|
||||
{
|
||||
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex;
|
||||
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(*GetExpressionType(*node.targetFunction)).funcIndex;
|
||||
|
||||
UInt32 funcId = 0;
|
||||
for (const auto& [funcIndex, func] : m_funcData)
|
||||
@@ -434,7 +434,7 @@ namespace Nz
|
||||
UInt32 resultId = AllocateResultId();
|
||||
m_currentBlock->AppendVariadic(SpirvOp::OpFunctionCall, [&](auto&& appender)
|
||||
{
|
||||
appender(m_writer.GetTypeId(ShaderAst::GetExpressionType(node)));
|
||||
appender(m_writer.GetTypeId(*ShaderAst::GetExpressionType(node)));
|
||||
appender(resultId);
|
||||
appender(funcId);
|
||||
|
||||
@@ -718,9 +718,11 @@ namespace Nz
|
||||
{
|
||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||
|
||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsVectorType(parameterType));
|
||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
||||
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(parameterType);
|
||||
assert(IsVectorType(*parameterType));
|
||||
|
||||
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||
|
||||
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
|
||||
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
|
||||
@@ -733,10 +735,11 @@ namespace Nz
|
||||
|
||||
case ShaderAst::IntrinsicType::DotProduct:
|
||||
{
|
||||
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsVectorType(vecExprType));
|
||||
const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
assert(vecExprType);
|
||||
assert(IsVectorType(*vecExprType));
|
||||
|
||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
|
||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
|
||||
|
||||
UInt32 typeId = m_writer.GetTypeId(vecType.type);
|
||||
|
||||
@@ -754,9 +757,10 @@ namespace Nz
|
||||
{
|
||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||
|
||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
|
||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
||||
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(parameterType);
|
||||
assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
|
||||
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||
|
||||
UInt32 param = EvaluateExpression(node.parameters[0]);
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
@@ -770,10 +774,11 @@ namespace Nz
|
||||
{
|
||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||
|
||||
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsVectorType(vecExprType));
|
||||
const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
assert(vecExprType);
|
||||
assert(IsVectorType(*vecExprType));
|
||||
|
||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
|
||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
|
||||
UInt32 typeId = m_writer.GetTypeId(vecType.type);
|
||||
|
||||
UInt32 vec = EvaluateExpression(node.parameters[0]);
|
||||
@@ -790,15 +795,16 @@ namespace Nz
|
||||
{
|
||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||
|
||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
|
||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
||||
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(parameterType);
|
||||
assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
|
||||
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||
|
||||
ShaderAst::PrimitiveType basicType;
|
||||
if (IsPrimitiveType(parameterType))
|
||||
basicType = std::get<ShaderAst::PrimitiveType>(parameterType);
|
||||
else if (IsVectorType(parameterType))
|
||||
basicType = std::get<ShaderAst::VectorType>(parameterType).type;
|
||||
if (IsPrimitiveType(*parameterType))
|
||||
basicType = std::get<ShaderAst::PrimitiveType>(*parameterType);
|
||||
else if (IsVectorType(*parameterType))
|
||||
basicType = std::get<ShaderAst::VectorType>(*parameterType).type;
|
||||
else
|
||||
throw std::runtime_error("unexpected expression type");
|
||||
|
||||
@@ -837,10 +843,11 @@ namespace Nz
|
||||
{
|
||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||
|
||||
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsVectorType(vecExprType));
|
||||
const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
|
||||
assert(vecExprType);
|
||||
assert(IsVectorType(*vecExprType));
|
||||
|
||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
|
||||
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
|
||||
UInt32 typeId = m_writer.GetTypeId(vecType);
|
||||
|
||||
UInt32 vec = EvaluateExpression(node.parameters[0]);
|
||||
@@ -856,9 +863,10 @@ namespace Nz
|
||||
{
|
||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||
|
||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
|
||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
||||
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(parameterType);
|
||||
assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
|
||||
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||
|
||||
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
|
||||
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
|
||||
@@ -873,9 +881,10 @@ namespace Nz
|
||||
{
|
||||
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
|
||||
|
||||
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(IsVectorType(parameterType));
|
||||
UInt32 typeId = m_writer.GetTypeId(parameterType);
|
||||
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
|
||||
assert(parameterType);
|
||||
assert(IsVectorType(*parameterType));
|
||||
UInt32 typeId = m_writer.GetTypeId(*parameterType);
|
||||
|
||||
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
|
||||
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
|
||||
@@ -951,20 +960,22 @@ namespace Nz
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
|
||||
{
|
||||
const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression);
|
||||
const ShaderAst::ExpressionType* swizzledExpressionType = GetExpressionType(*node.expression);
|
||||
assert(swizzledExpressionType);
|
||||
|
||||
UInt32 exprResultId = EvaluateExpression(node.expression);
|
||||
|
||||
const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node);
|
||||
const ShaderAst::ExpressionType* targetExprType = GetExpressionType(node);
|
||||
assert(targetExprType);
|
||||
|
||||
if (node.componentCount > 1)
|
||||
{
|
||||
assert(IsVectorType(targetExprType));
|
||||
assert(IsVectorType(*targetExprType));
|
||||
|
||||
const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(targetExprType);
|
||||
const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(*targetExprType);
|
||||
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
if (IsVectorType(swizzledExpressionType))
|
||||
if (IsVectorType(*swizzledExpressionType))
|
||||
{
|
||||
// Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands
|
||||
m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
|
||||
@@ -980,7 +991,7 @@ namespace Nz
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(IsPrimitiveType(swizzledExpressionType));
|
||||
assert(IsPrimitiveType(*swizzledExpressionType));
|
||||
|
||||
// Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct
|
||||
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
|
||||
@@ -995,10 +1006,10 @@ namespace Nz
|
||||
|
||||
PushResultId(resultId);
|
||||
}
|
||||
else if (IsVectorType(swizzledExpressionType))
|
||||
else if (IsVectorType(*swizzledExpressionType))
|
||||
{
|
||||
assert(IsPrimitiveType(targetExprType));
|
||||
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
|
||||
assert(IsPrimitiveType(*targetExprType));
|
||||
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(*targetExprType);
|
||||
|
||||
// Extract a single component from the vector
|
||||
assert(node.componentCount == 1);
|
||||
@@ -1011,8 +1022,8 @@ namespace Nz
|
||||
else
|
||||
{
|
||||
// Swizzling a primitive to itself (a.x for example), don't do anything
|
||||
assert(IsPrimitiveType(swizzledExpressionType));
|
||||
assert(IsPrimitiveType(targetExprType));
|
||||
assert(IsPrimitiveType(*swizzledExpressionType));
|
||||
assert(IsPrimitiveType(*targetExprType));
|
||||
assert(node.componentCount == 1);
|
||||
assert(node.components[0] == 0);
|
||||
|
||||
@@ -1022,8 +1033,11 @@ namespace Nz
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)
|
||||
{
|
||||
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expression);
|
||||
const ShaderAst::ExpressionType* resultType = GetExpressionType(node);
|
||||
assert(resultType);
|
||||
|
||||
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expression);
|
||||
assert(exprType);
|
||||
|
||||
UInt32 operand = EvaluateExpression(node.expression);
|
||||
|
||||
@@ -1033,11 +1047,11 @@ namespace Nz
|
||||
{
|
||||
case ShaderAst::UnaryType::LogicalNot:
|
||||
{
|
||||
assert(IsPrimitiveType(exprType));
|
||||
assert(std::get<ShaderAst::PrimitiveType>(resultType) == ShaderAst::PrimitiveType::Boolean);
|
||||
assert(IsPrimitiveType(*exprType));
|
||||
assert(std::get<ShaderAst::PrimitiveType>(*resultType) == ShaderAst::PrimitiveType::Boolean);
|
||||
|
||||
UInt32 resultId = m_writer.AllocateResultId();
|
||||
m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(resultType), resultId, operand);
|
||||
m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(*resultType), resultId, operand);
|
||||
|
||||
return resultId;
|
||||
}
|
||||
@@ -1045,10 +1059,10 @@ namespace Nz
|
||||
case ShaderAst::UnaryType::Minus:
|
||||
{
|
||||
ShaderAst::PrimitiveType basicType;
|
||||
if (IsPrimitiveType(exprType))
|
||||
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
|
||||
else if (IsVectorType(exprType))
|
||||
basicType = std::get<ShaderAst::VectorType>(exprType).type;
|
||||
if (IsPrimitiveType(*exprType))
|
||||
basicType = std::get<ShaderAst::PrimitiveType>(*exprType);
|
||||
else if (IsVectorType(*exprType))
|
||||
basicType = std::get<ShaderAst::VectorType>(*exprType).type;
|
||||
else
|
||||
throw std::runtime_error("unexpected expression type");
|
||||
|
||||
@@ -1057,12 +1071,12 @@ namespace Nz
|
||||
switch (basicType)
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Float32:
|
||||
m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(resultType), resultId, operand);
|
||||
m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(*resultType), resultId, operand);
|
||||
return resultId;
|
||||
|
||||
case ShaderAst::PrimitiveType::Int32:
|
||||
case ShaderAst::PrimitiveType::UInt32:
|
||||
m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(resultType), resultId, operand);
|
||||
m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(*resultType), resultId, operand);
|
||||
return resultId;
|
||||
|
||||
default:
|
||||
|
||||
@@ -76,9 +76,10 @@ namespace Nz
|
||||
{
|
||||
node.expr->Visit(*this);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
|
||||
const ShaderAst::ExpressionType* exprType = GetExpressionType(node);
|
||||
assert(exprType);
|
||||
|
||||
UInt32 typeId = m_writer.GetTypeId(exprType);
|
||||
UInt32 typeId = m_writer.GetTypeId(*exprType);
|
||||
|
||||
assert(node.indices.size() == 1);
|
||||
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
|
||||
@@ -88,7 +89,7 @@ namespace Nz
|
||||
[&](const Pointer& pointer)
|
||||
{
|
||||
PointerChainAccess pointerChainAccess;
|
||||
pointerChainAccess.exprType = &exprType;
|
||||
pointerChainAccess.exprType = exprType;
|
||||
pointerChainAccess.indices = { indexId };
|
||||
pointerChainAccess.pointedTypeId = pointer.pointedTypeId;
|
||||
pointerChainAccess.pointerId = pointer.pointerId;
|
||||
@@ -98,7 +99,7 @@ namespace Nz
|
||||
},
|
||||
[&](PointerChainAccess& pointerChainAccess)
|
||||
{
|
||||
pointerChainAccess.exprType = &exprType;
|
||||
pointerChainAccess.exprType = exprType;
|
||||
pointerChainAccess.indices.push_back(indexId);
|
||||
},
|
||||
[&](const Value& value)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||
#include <Nazara/Shader/SpirvBlock.hpp>
|
||||
#include <Nazara/Shader/SpirvWriter.hpp>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
@@ -61,11 +62,12 @@ namespace Nz
|
||||
}
|
||||
else
|
||||
{
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node);
|
||||
|
||||
assert(swizzledPointer.componentCount == 1);
|
||||
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(exprType, swizzledPointer.storage); //< FIXME
|
||||
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node);
|
||||
assert(exprType);
|
||||
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(*exprType, swizzledPointer.storage); //< FIXME
|
||||
|
||||
// Access chain
|
||||
UInt32 indexId = m_writer.GetConstantId(SafeCast<Int32>(swizzledPointer.swizzleIndices[0]));
|
||||
@@ -86,14 +88,15 @@ namespace Nz
|
||||
{
|
||||
node.expr->Visit(*this);
|
||||
|
||||
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
|
||||
const ShaderAst::ExpressionType* exprType = GetExpressionType(node);
|
||||
assert(exprType);
|
||||
|
||||
std::visit(Overloaded
|
||||
{
|
||||
[&](const Pointer& pointer)
|
||||
{
|
||||
UInt32 resultId = m_visitor.AllocateResultId();
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
|
||||
UInt32 pointerType = m_writer.RegisterPointerType(*exprType, pointer.storage); //< FIXME
|
||||
|
||||
assert(node.indices.size() == 1);
|
||||
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
|
||||
@@ -117,13 +120,14 @@ namespace Nz
|
||||
{
|
||||
[&](const Pointer& pointer)
|
||||
{
|
||||
const auto& expressionType = GetExpressionType(*node.expression);
|
||||
assert(IsVectorType(expressionType));
|
||||
const ShaderAst::ExpressionType* expressionType = GetExpressionType(*node.expression);
|
||||
assert(expressionType);
|
||||
assert(IsVectorType(*expressionType));
|
||||
|
||||
SwizzledPointer swizzledPointer;
|
||||
swizzledPointer.pointerId = pointer.pointerId;
|
||||
swizzledPointer.storage = pointer.storage;
|
||||
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(expressionType);
|
||||
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(*expressionType);
|
||||
swizzledPointer.componentCount = node.componentCount;
|
||||
swizzledPointer.swizzleIndices = node.components;
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ namespace Nz
|
||||
for (const auto& parameter : node.parameters)
|
||||
{
|
||||
auto& var = func.variables.emplace_back();
|
||||
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
|
||||
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(*GetExpressionType(*parameter), SpirvStorageClass::Function));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user