Shader: Add support for partial sanitization

This commit is contained in:
SirLynix 2022-03-25 12:54:51 +01:00
parent a54f70fd24
commit 8146ec251a
31 changed files with 1105 additions and 521 deletions

View File

@ -65,7 +65,7 @@ namespace Nz
NazaraSignal(OnShaderUpdated, UberShader* /*uberShader*/);
private:
void Validate(ShaderAst::Module& module);
ShaderAst::ModulePtr Validate(const ShaderAst::Module& module, std::unordered_map<std::string, Option>* options);
NazaraSlot(ShaderModuleResolver, OnModuleUpdated, m_onShaderModuleUpdated);

View File

@ -57,6 +57,7 @@ namespace Nz::ShaderAst
virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node);
virtual ExpressionPtr Clone(StructTypeExpression& node);
virtual ExpressionPtr Clone(SwizzleExpression& node);
virtual ExpressionPtr Clone(TypeExpression& node);
virtual ExpressionPtr Clone(VariableValueExpression& node);
virtual ExpressionPtr Clone(UnaryExpression& node);

View File

@ -49,6 +49,7 @@ namespace Nz::ShaderAst
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs);
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs);
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs);
inline bool Compare(const TypeExpression& lhs, const TypeExpression& rhs);
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs);
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs);

View File

@ -407,6 +407,14 @@ namespace Nz::ShaderAst
return true;
}
bool Compare(const TypeExpression& lhs, const TypeExpression& rhs)
{
if (!Compare(lhs.typeId, rhs.typeId))
return false;
return true;
}
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs)
{
if (!Compare(lhs.variableId, rhs.variableId))

View File

@ -36,7 +36,7 @@ namespace Nz::ShaderAst
struct Options
{
std::function<const ConstantValue&(std::size_t constantId)> constantQueryCallback;
std::function<const ConstantValue*(std::size_t constantId)> constantQueryCallback;
};
protected:

View File

@ -45,6 +45,7 @@ NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression)
NAZARA_SHADERAST_EXPRESSION(StructTypeExpression)
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
NAZARA_SHADERAST_EXPRESSION(TypeExpression)
NAZARA_SHADERAST_EXPRESSION(VariableValueExpression)
NAZARA_SHADERAST_EXPRESSION(UnaryExpression)
NAZARA_SHADERAST_STATEMENT(BranchStatement)

View File

@ -37,6 +37,7 @@ namespace Nz::ShaderAst
void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(TypeExpression& node) override;
void Visit(VariableValueExpression& node) override;
void Visit(UnaryExpression& node) override;

View File

@ -40,6 +40,7 @@ namespace Nz::ShaderAst
void Serialize(IntrinsicFunctionExpression& node);
void Serialize(StructTypeExpression& node);
void Serialize(SwizzleExpression& node);
void Serialize(TypeExpression& node);
void Serialize(VariableValueExpression& node);
void Serialize(UnaryExpression& node);
void SerializeExpressionCommon(Expression& expr);

View File

@ -48,6 +48,7 @@ namespace Nz::ShaderAst
void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override;
void Visit(TypeExpression& node) override;
void Visit(VariableValueExpression& node) override;
void Visit(UnaryExpression& node) override;

View File

@ -37,7 +37,7 @@ namespace Nz::ShaderAst
using ConstantValue = TypeListInstantiate<ConstantTypes, std::variant>;
NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant);
NAZARA_SHADER_API ExpressionType GetConstantType(const ConstantValue& constant);
}
#endif // NAZARA_SHADER_AST_CONSTANTVALUE_HPP

View File

@ -215,6 +215,14 @@ namespace Nz::ShaderAst
ExpressionPtr expression;
};
struct NAZARA_SHADER_API TypeExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t typeId;
};
struct NAZARA_SHADER_API VariableValueExpression : Expression
{
NodeType GetType() const override;
@ -462,8 +470,8 @@ namespace Nz::ShaderAst
#include <Nazara/Shader/Ast/AstNodeList.hpp>
inline const ExpressionType& GetExpressionType(Expression& expr);
inline ExpressionType& GetExpressionTypeMut(Expression& expr);
inline const ExpressionType* GetExpressionType(Expression& expr);
inline ExpressionType* GetExpressionTypeMut(Expression& expr);
inline bool IsExpression(NodeType nodeType);
inline bool IsStatement(NodeType nodeType);

View File

@ -7,16 +7,14 @@
namespace Nz::ShaderAst
{
inline const ExpressionType& GetExpressionType(Expression& expr)
inline const ExpressionType* GetExpressionType(Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
}
inline ExpressionType& GetExpressionTypeMut(Expression& expr)
inline ExpressionType* GetExpressionTypeMut(Expression& expr)
{
assert(expr.cachedExpressionType);
return expr.cachedExpressionType.value();
return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
}
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)

View File

@ -46,6 +46,7 @@ namespace Nz::ShaderAst
std::shared_ptr<ShaderModuleResolver> moduleResolver;
std::unordered_set<std::string> reservedIdentifiers;
std::unordered_map<UInt32, ConstantValue> optionValues;
bool allowPartialSanitization = false;
bool makeVariableNameUnique = false;
bool reduceLoopsToWhile = false;
bool removeAliases = false;
@ -60,6 +61,7 @@ namespace Nz::ShaderAst
private:
enum class IdentifierCategory;
enum class ValidationResult;
struct AstError;
struct CurrentFunctionData;
struct Environment;
@ -110,10 +112,12 @@ namespace Nz::ShaderAst
template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
const ExpressionType* GetExpressionType(Expression& expr) const;
const ExpressionType& GetExpressionTypeSecure(Expression& expr) const;
Expression& MandatoryExpr(const ExpressionPtr& node) const;
Statement& MandatoryStatement(const StatementPtr& node) const;
@ -122,8 +126,9 @@ namespace Nz::ShaderAst
ExpressionPtr CacheResult(ExpressionPtr expression);
ConstantValue ComputeConstantValue(Expression& expr) const;
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const;
std::optional<ConstantValue> ComputeConstantValue(Expression& expr) const;
template<typename T> ValidationResult ComputeExprValue(ExpressionValue<T>& attribute) const;
template<typename T> ValidationResult ComputeExprValue(const ExpressionValue<T>& attribute, ExpressionValue<T>& targetAttribute);
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
void PreregisterIndices(const Module& module);
@ -131,49 +136,49 @@ namespace Nz::ShaderAst
void RegisterBuiltin();
std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional<std::size_t> index = {});
std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {});
std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {});
std::size_t RegisterAlias(std::string name, std::optional<IdentifierData> aliasData, std::optional<std::size_t> index = {});
std::size_t RegisterConstant(std::string name, std::optional<ConstantValue> value, std::optional<std::size_t> index = {});
std::size_t RegisterFunction(std::string name, std::optional<FunctionData> funcData, std::optional<std::size_t> index = {});
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {});
std::size_t RegisterStruct(std::string name, std::optional<StructDescription*> description, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, std::optional<ExpressionType> expressionType, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, std::optional<PartialType> partialType, std::optional<std::size_t> index = {});
void RegisterUnresolved(std::string name);
std::size_t RegisterVariable(std::string name, std::optional<ExpressionType> type, std::optional<std::size_t> index = {});
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
std::size_t ResolveStruct(const AliasType& aliasType);
std::size_t ResolveStruct(const ExpressionType& exprType);
std::size_t ResolveStruct(const IdentifierType& identifierType);
std::size_t ResolveStruct(const StructType& structType);
std::size_t ResolveStruct(const UniformType& uniformType);
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
std::optional<ExpressionType> ResolveTypeExpr(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
void SanitizeIdentifier(std::string& identifier);
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
StatementPtr Unscope(StatementPtr node);
void Validate(DeclareAliasStatement& node);
void Validate(WhileStatement& node);
ValidationResult Validate(DeclareAliasStatement& node);
ValidationResult Validate(WhileStatement& node);
void Validate(AccessIndexExpression& node);
void Validate(AssignExpression& node);
void Validate(BinaryExpression& node);
void Validate(CallFunctionExpression& node);
void Validate(CastExpression& node);
void Validate(DeclareVariableStatement& node);
void Validate(IntrinsicExpression& node);
void Validate(SwizzleExpression& node);
void Validate(UnaryExpression& node);
void Validate(VariableValueExpression& node);
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr);
ValidationResult Validate(AccessIndexExpression& node);
ValidationResult Validate(AssignExpression& node);
ValidationResult Validate(BinaryExpression& node);
ValidationResult Validate(CallFunctionExpression& node);
ValidationResult Validate(CastExpression& node);
ValidationResult Validate(DeclareVariableStatement& node);
ValidationResult Validate(IntrinsicExpression& node);
ValidationResult Validate(SwizzleExpression& node);
ValidationResult Validate(UnaryExpression& node);
ValidationResult Validate(VariableValueExpression& node);
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType);
enum class IdentifierCategory
{
@ -184,9 +189,16 @@ namespace Nz::ShaderAst
Module,
Struct,
Type,
Unresolved,
Variable
};
enum class ValidationResult
{
Validated,
Unresolved
};
struct FunctionData
{
Bitset<> calledByFunctions;

View File

@ -173,7 +173,7 @@ namespace Nz::ShaderBuilder
{
auto constantNode = std::make_unique<ShaderAst::ConstantValueExpression>();
constantNode->value = std::move(value);
constantNode->cachedExpressionType = ShaderAst::GetExpressionType(constantNode->value);
constantNode->cachedExpressionType = ShaderAst::GetConstantType(constantNode->value);
return constantNode;
}

View File

@ -249,7 +249,7 @@ struct VertIn
}
[entry(vert), cond(Billboard)]
fn billboardMain(input: VertIn) -> VertOut
fn billboardMain(input: VertIn) -> VertToFrag
{
let size = input.billboardSizeRot.xy;
let sinCos = input.billboardSizeRot.zw;

View File

@ -25,7 +25,7 @@ namespace Nz
m_shaderModule = moduleResolver.Resolve(moduleName);
NazaraAssert(m_shaderModule, "invalid shader module");
Validate(*m_shaderModule);
m_shaderModule = Validate(*m_shaderModule, &m_optionIndexByName);
m_onShaderModuleUpdated.Connect(moduleResolver.OnModuleUpdated, [this, name = std::move(moduleName)](ShaderModuleResolver* resolver, const std::string& updatedModuleName)
{
@ -41,8 +41,7 @@ namespace Nz
try
{
// FIXME: Validate is destructive, in case of failure it can invalidate the shader
Validate(*newShaderModule);
m_shaderModule = Validate(*newShaderModule, &m_optionIndexByName);
}
catch (const std::exception& e)
{
@ -50,8 +49,6 @@ namespace Nz
return;
}
m_shaderModule = std::move(newShaderModule);
// Clear cache
m_combinations.clear();
@ -65,7 +62,7 @@ namespace Nz
{
NazaraAssert(m_shaderModule, "invalid shader module");
Validate(*m_shaderModule);
Validate(*m_shaderModule, &m_optionIndexByName);
}
const std::shared_ptr<ShaderModule>& UberShader::Get(const Config& config)
@ -85,13 +82,17 @@ namespace Nz
return it->second;
}
void UberShader::Validate(ShaderAst::Module& module)
ShaderAst::ModulePtr UberShader::Validate(const ShaderAst::Module& module, std::unordered_map<std::string, Option>* options)
{
NazaraAssert(m_shaderStages != 0, "there must be at least one shader stage");
assert(options);
//TODO: Try to partially sanitize shader?
// Try to partially sanitize shader
std::size_t optionCount = 0;
ShaderAst::SanitizeVisitor::Options sanitizeOptions;
sanitizeOptions.allowPartialSanitization = true;
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module, sanitizeOptions);
ShaderStageTypeFlags supportedStageType;
@ -101,21 +102,24 @@ namespace Nz
supportedStageType |= stageType;
};
std::unordered_map<std::string, Option> optionByName;
callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option)
{
//TODO: Check optionType
m_optionIndexByName[option.optName] = Option{
optionByName[option.optName] = Option{
CRC32(option.optName)
};
optionCount++;
};
ShaderAst::AstReflect reflect;
reflect.Reflect(*module.rootNode, callbacks);
reflect.Reflect(*sanitizedModule->rootNode, callbacks);
if ((m_shaderStages & supportedStageType) != m_shaderStages)
throw std::runtime_error("shader doesn't support all required shader stages");
*options = std::move(optionByName);
return sanitizedModule;
}
}

View File

@ -481,6 +481,16 @@ namespace Nz::ShaderAst
return clone;
}
ExpressionPtr AstCloner::Clone(TypeExpression& node)
{
auto clone = std::make_unique<TypeExpression>();
clone->typeId = node.typeId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(VariableValueExpression& node)
{
auto clone = std::make_unique<VariableValueExpression>();

View File

@ -862,7 +862,7 @@ namespace Nz::ShaderAst
const auto& constantExpr = static_cast<ConstantValueExpression&>(*expressions[i]);
if (!constantValues.empty() && GetExpressionType(constantValues.front()) != GetExpressionType(constantExpr.value))
if (!constantValues.empty() && GetConstantType(constantValues.front()) != GetConstantType(constantExpr.value))
{
// Unhandled case, all cast parameters are expected to be of the same type
constantValues.clear();
@ -940,16 +940,24 @@ namespace Nz::ShaderAst
std::vector<BranchStatement::ConditionalStatement> statements;
StatementPtr elseStatement;
bool continuePropagation = true;
for (auto& condStatement : node.condStatements)
{
auto cond = CloneExpression(condStatement.condition);
if (cond->GetType() == NodeType::ConstantValueExpression)
if (continuePropagation && cond->GetType() == NodeType::ConstantValueExpression)
{
auto& constant = static_cast<ConstantValueExpression&>(*cond);
const ExpressionType& constantType = GetExpressionType(constant);
if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean)
const ExpressionType* constantType = GetExpressionType(constant);
if (!constantType)
{
// unresolved type, can't continue propagating this branch
continuePropagation = false;
continue;
}
if (!IsPrimitiveType(*constantType) || std::get<PrimitiveType>(*constantType) != PrimitiveType::Boolean)
continue;
bool cValue = std::get<bool>(constant.value);
@ -1017,8 +1025,12 @@ namespace Nz::ShaderAst
if (!m_options.constantQueryCallback)
return AstCloner::Clone(node);
auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId));
constant->cachedExpressionType = GetExpressionType(constant->value);
const ConstantValue* constantValue = m_options.constantQueryCallback(node.constantId);
if (!constantValue)
return AstCloner::Clone(node);
auto constant = ShaderBuilder::Constant(*constantValue);
constant->cachedExpressionType = GetConstantType(constant->value);
return constant;
}
@ -1155,7 +1167,7 @@ namespace Nz::ShaderAst
}, lhs.value);
if (optimized)
optimized->cachedExpressionType = GetExpressionType(optimized->value);
optimized->cachedExpressionType = GetConstantType(optimized->value);
return optimized;
}
@ -1221,7 +1233,7 @@ namespace Nz::ShaderAst
}, operand.value);
if (optimized)
optimized->cachedExpressionType = GetExpressionType(optimized->value);
optimized->cachedExpressionType = GetConstantType(optimized->value);
return optimized;
}

View File

@ -109,6 +109,11 @@ namespace Nz::ShaderAst
node.expression->Visit(*this);
}
void AstRecursiveVisitor::Visit(TypeExpression& node)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(VariableValueExpression& /*node*/)
{
/* Nothing to do */

View File

@ -174,6 +174,11 @@ namespace Nz::ShaderAst
SizeT(node.structTypeId);
}
void AstSerializerBase::Serialize(TypeExpression& node)
{
SizeT(node.typeId);
}
void AstSerializerBase::Serialize(FunctionExpression& node)
{
SizeT(node.funcId);

View File

@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/AstUtils.hpp>
#include <cassert>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst
@ -104,7 +105,10 @@ namespace Nz::ShaderAst
void ShaderAstValueCategory::Visit(SwizzleExpression& node)
{
if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1)
const ExpressionType* exprType = GetExpressionType(node);
assert(exprType);
if (IsPrimitiveType(*exprType) && node.componentCount > 1)
// Swizzling more than a component on a primitive produces a rvalue (a.xxxx cannot be assigned)
m_expressionCategory = ExpressionCategory::RValue;
else
@ -133,6 +137,11 @@ namespace Nz::ShaderAst
}
}
void ShaderAstValueCategory::Visit(TypeExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(VariableValueExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;

View File

@ -8,7 +8,7 @@
namespace Nz::ShaderAst
{
ExpressionType GetExpressionType(const ConstantValue& constant)
ExpressionType GetConstantType(const ConstantValue& constant)
{
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{

View File

@ -15,10 +15,11 @@ namespace Nz::ShaderAst
void DependencyCheckerVisitor::Visit(CallFunctionExpression& node)
{
const auto& targetFuncType = GetExpressionType(*node.targetFunction);
assert(std::holds_alternative<FunctionType>(targetFuncType));
const ExpressionType* targetFuncType = GetExpressionType(*node.targetFunction);
assert(targetFuncType);
assert(std::holds_alternative<FunctionType>(*targetFuncType));
const auto& funcType = std::get<FunctionType>(targetFuncType);
const auto& funcType = std::get<FunctionType>(*targetFuncType);
assert(m_currentFunctionIndex);
if (m_currentVariableDeclIndex)

File diff suppressed because it is too large Load Diff

View File

@ -36,7 +36,7 @@ namespace Nz
AstRecursiveVisitor::Visit(node);
assert(currentFunction);
currentFunction->calledFunctions.UnboundedSet(std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex);
currentFunction->calledFunctions.UnboundedSet(std::get<ShaderAst::FunctionType>(*GetExpressionType(*node.targetFunction)).funcIndex);
}
void Visit(ShaderAst::ConditionalExpression& /*node*/) override
@ -307,7 +307,7 @@ namespace Nz
Append(type.GetResultingValue());
}
void GlslWriter::Append(const ShaderAst::FunctionType& functionType)
void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
{
throw std::runtime_error("unexpected FunctionType");
}
@ -829,8 +829,9 @@ namespace Nz
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(IsStructType(exprType));
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr);
assert(exprType);
assert(IsStructType(*exprType));
for (const std::string& identifier : node.identifiers)
Append(".", identifier);
@ -840,8 +841,9 @@ namespace Nz
{
Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr);
assert(!IsStructType(exprType));
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr);
assert(exprType);
assert(!IsStructType(*exprType));
// Array access
assert(node.indices.size() == 1);
@ -1326,9 +1328,10 @@ namespace Nz
{
assert(node.returnExpr);
const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr);
assert(IsStructType(returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex;
const ShaderAst::ExpressionType* returnType = GetExpressionType(*node.returnExpr);
assert(returnType);
assert(IsStructType(*returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(*returnType).structIndex;
const auto& structData = Retrieve(m_currentState->structs, structIndex);
std::string outputStructVarName;

View File

@ -182,7 +182,7 @@ namespace Nz
type.GetExpression()->Visit(*this);
}
void LangWriter::Append(const ShaderAst::FunctionType& functionType)
void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
{
throw std::runtime_error("unexpected function type");
}

View File

@ -263,7 +263,7 @@ namespace Nz::ShaderLang
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
const std::string& versionStr = std::get<std::string>(constantValue.value);
@ -302,7 +302,7 @@ namespace Nz::ShaderLang
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
const std::string& uuidStr = std::get<std::string>(constantValue.value);

View File

@ -67,9 +67,9 @@ namespace Nz
throw std::runtime_error("unexpected type");
};
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
const ShaderAst::ExpressionType& leftType = GetExpressionType(*node.left);
const ShaderAst::ExpressionType& rightType = GetExpressionType(*node.right);
const ShaderAst::ExpressionType& resultType = *GetExpressionType(node);
const ShaderAst::ExpressionType& leftType = *GetExpressionType(*node.left);
const ShaderAst::ExpressionType& rightType = *GetExpressionType(*node.right);
ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType);
//ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType);
@ -405,7 +405,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node)
{
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex;
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(*GetExpressionType(*node.targetFunction)).funcIndex;
UInt32 funcId = 0;
for (const auto& [funcIndex, func] : m_funcData)
@ -434,7 +434,7 @@ namespace Nz
UInt32 resultId = AllocateResultId();
m_currentBlock->AppendVariadic(SpirvOp::OpFunctionCall, [&](auto&& appender)
{
appender(m_writer.GetTypeId(ShaderAst::GetExpressionType(node)));
appender(m_writer.GetTypeId(*ShaderAst::GetExpressionType(node)));
appender(resultId);
appender(funcId);
@ -718,9 +718,11 @@ namespace Nz
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(parameterType));
UInt32 typeId = m_writer.GetTypeId(parameterType);
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(parameterType);
assert(IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
@ -733,10 +735,11 @@ namespace Nz
case ShaderAst::IntrinsicType::DotProduct:
{
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(vecExprType));
const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
assert(vecExprType);
assert(IsVectorType(*vecExprType));
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
UInt32 typeId = m_writer.GetTypeId(vecType.type);
@ -754,9 +757,10 @@ namespace Nz
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
UInt32 typeId = m_writer.GetTypeId(parameterType);
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(parameterType);
assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
UInt32 param = EvaluateExpression(node.parameters[0]);
UInt32 resultId = m_writer.AllocateResultId();
@ -770,10 +774,11 @@ namespace Nz
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(vecExprType));
const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
assert(vecExprType);
assert(IsVectorType(*vecExprType));
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
UInt32 typeId = m_writer.GetTypeId(vecType.type);
UInt32 vec = EvaluateExpression(node.parameters[0]);
@ -790,15 +795,16 @@ namespace Nz
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
UInt32 typeId = m_writer.GetTypeId(parameterType);
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(parameterType);
assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
ShaderAst::PrimitiveType basicType;
if (IsPrimitiveType(parameterType))
basicType = std::get<ShaderAst::PrimitiveType>(parameterType);
else if (IsVectorType(parameterType))
basicType = std::get<ShaderAst::VectorType>(parameterType).type;
if (IsPrimitiveType(*parameterType))
basicType = std::get<ShaderAst::PrimitiveType>(*parameterType);
else if (IsVectorType(*parameterType))
basicType = std::get<ShaderAst::VectorType>(*parameterType).type;
else
throw std::runtime_error("unexpected expression type");
@ -837,10 +843,11 @@ namespace Nz
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(vecExprType));
const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
assert(vecExprType);
assert(IsVectorType(*vecExprType));
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType);
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
UInt32 typeId = m_writer.GetTypeId(vecType);
UInt32 vec = EvaluateExpression(node.parameters[0]);
@ -856,9 +863,10 @@ namespace Nz
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType));
UInt32 typeId = m_writer.GetTypeId(parameterType);
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(parameterType);
assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
@ -873,9 +881,10 @@ namespace Nz
{
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(parameterType));
UInt32 typeId = m_writer.GetTypeId(parameterType);
const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(parameterType);
assert(IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
UInt32 firstParam = EvaluateExpression(node.parameters[0]);
UInt32 secondParam = EvaluateExpression(node.parameters[1]);
@ -951,20 +960,22 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{
const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression);
const ShaderAst::ExpressionType* swizzledExpressionType = GetExpressionType(*node.expression);
assert(swizzledExpressionType);
UInt32 exprResultId = EvaluateExpression(node.expression);
const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node);
const ShaderAst::ExpressionType* targetExprType = GetExpressionType(node);
assert(targetExprType);
if (node.componentCount > 1)
{
assert(IsVectorType(targetExprType));
assert(IsVectorType(*targetExprType));
const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(targetExprType);
const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(*targetExprType);
UInt32 resultId = m_writer.AllocateResultId();
if (IsVectorType(swizzledExpressionType))
if (IsVectorType(*swizzledExpressionType))
{
// Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands
m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
@ -980,7 +991,7 @@ namespace Nz
}
else
{
assert(IsPrimitiveType(swizzledExpressionType));
assert(IsPrimitiveType(*swizzledExpressionType));
// Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
@ -995,10 +1006,10 @@ namespace Nz
PushResultId(resultId);
}
else if (IsVectorType(swizzledExpressionType))
else if (IsVectorType(*swizzledExpressionType))
{
assert(IsPrimitiveType(targetExprType));
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType);
assert(IsPrimitiveType(*targetExprType));
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(*targetExprType);
// Extract a single component from the vector
assert(node.componentCount == 1);
@ -1011,8 +1022,8 @@ namespace Nz
else
{
// Swizzling a primitive to itself (a.x for example), don't do anything
assert(IsPrimitiveType(swizzledExpressionType));
assert(IsPrimitiveType(targetExprType));
assert(IsPrimitiveType(*swizzledExpressionType));
assert(IsPrimitiveType(*targetExprType));
assert(node.componentCount == 1);
assert(node.components[0] == 0);
@ -1022,8 +1033,11 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)
{
const ShaderAst::ExpressionType& resultType = GetExpressionType(node);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expression);
const ShaderAst::ExpressionType* resultType = GetExpressionType(node);
assert(resultType);
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expression);
assert(exprType);
UInt32 operand = EvaluateExpression(node.expression);
@ -1033,11 +1047,11 @@ namespace Nz
{
case ShaderAst::UnaryType::LogicalNot:
{
assert(IsPrimitiveType(exprType));
assert(std::get<ShaderAst::PrimitiveType>(resultType) == ShaderAst::PrimitiveType::Boolean);
assert(IsPrimitiveType(*exprType));
assert(std::get<ShaderAst::PrimitiveType>(*resultType) == ShaderAst::PrimitiveType::Boolean);
UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(resultType), resultId, operand);
m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(*resultType), resultId, operand);
return resultId;
}
@ -1045,10 +1059,10 @@ namespace Nz
case ShaderAst::UnaryType::Minus:
{
ShaderAst::PrimitiveType basicType;
if (IsPrimitiveType(exprType))
basicType = std::get<ShaderAst::PrimitiveType>(exprType);
else if (IsVectorType(exprType))
basicType = std::get<ShaderAst::VectorType>(exprType).type;
if (IsPrimitiveType(*exprType))
basicType = std::get<ShaderAst::PrimitiveType>(*exprType);
else if (IsVectorType(*exprType))
basicType = std::get<ShaderAst::VectorType>(*exprType).type;
else
throw std::runtime_error("unexpected expression type");
@ -1057,12 +1071,12 @@ namespace Nz
switch (basicType)
{
case ShaderAst::PrimitiveType::Float32:
m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(resultType), resultId, operand);
m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(*resultType), resultId, operand);
return resultId;
case ShaderAst::PrimitiveType::Int32:
case ShaderAst::PrimitiveType::UInt32:
m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(resultType), resultId, operand);
m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(*resultType), resultId, operand);
return resultId;
default:

View File

@ -76,9 +76,10 @@ namespace Nz
{
node.expr->Visit(*this);
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
const ShaderAst::ExpressionType* exprType = GetExpressionType(node);
assert(exprType);
UInt32 typeId = m_writer.GetTypeId(exprType);
UInt32 typeId = m_writer.GetTypeId(*exprType);
assert(node.indices.size() == 1);
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
@ -88,7 +89,7 @@ namespace Nz
[&](const Pointer& pointer)
{
PointerChainAccess pointerChainAccess;
pointerChainAccess.exprType = &exprType;
pointerChainAccess.exprType = exprType;
pointerChainAccess.indices = { indexId };
pointerChainAccess.pointedTypeId = pointer.pointedTypeId;
pointerChainAccess.pointerId = pointer.pointerId;
@ -98,7 +99,7 @@ namespace Nz
},
[&](PointerChainAccess& pointerChainAccess)
{
pointerChainAccess.exprType = &exprType;
pointerChainAccess.exprType = exprType;
pointerChainAccess.indices.push_back(indexId);
},
[&](const Value& value)

View File

@ -8,6 +8,7 @@
#include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp>
#include <cassert>
#include <numeric>
#include <Nazara/Shader/Debug.hpp>
@ -61,11 +62,12 @@ namespace Nz
}
else
{
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node);
assert(swizzledPointer.componentCount == 1);
UInt32 pointerType = m_writer.RegisterPointerType(exprType, swizzledPointer.storage); //< FIXME
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node);
assert(exprType);
UInt32 pointerType = m_writer.RegisterPointerType(*exprType, swizzledPointer.storage); //< FIXME
// Access chain
UInt32 indexId = m_writer.GetConstantId(SafeCast<Int32>(swizzledPointer.swizzleIndices[0]));
@ -86,14 +88,15 @@ namespace Nz
{
node.expr->Visit(*this);
const ShaderAst::ExpressionType& exprType = GetExpressionType(node);
const ShaderAst::ExpressionType* exprType = GetExpressionType(node);
assert(exprType);
std::visit(Overloaded
{
[&](const Pointer& pointer)
{
UInt32 resultId = m_visitor.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME
UInt32 pointerType = m_writer.RegisterPointerType(*exprType, pointer.storage); //< FIXME
assert(node.indices.size() == 1);
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
@ -117,13 +120,14 @@ namespace Nz
{
[&](const Pointer& pointer)
{
const auto& expressionType = GetExpressionType(*node.expression);
assert(IsVectorType(expressionType));
const ShaderAst::ExpressionType* expressionType = GetExpressionType(*node.expression);
assert(expressionType);
assert(IsVectorType(*expressionType));
SwizzledPointer swizzledPointer;
swizzledPointer.pointerId = pointer.pointerId;
swizzledPointer.storage = pointer.storage;
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(expressionType);
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(*expressionType);
swizzledPointer.componentCount = node.componentCount;
swizzledPointer.swizzleIndices = node.components;

View File

@ -98,7 +98,7 @@ namespace Nz
for (const auto& parameter : node.parameters)
{
auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function));
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(*GetExpressionType(*parameter), SpirvStorageClass::Function));
}
}