Shader: Add support for partial sanitization

This commit is contained in:
SirLynix 2022-03-25 12:54:51 +01:00
parent a54f70fd24
commit 8146ec251a
31 changed files with 1105 additions and 521 deletions

View File

@ -65,7 +65,7 @@ namespace Nz
NazaraSignal(OnShaderUpdated, UberShader* /*uberShader*/); NazaraSignal(OnShaderUpdated, UberShader* /*uberShader*/);
private: private:
void Validate(ShaderAst::Module& module); ShaderAst::ModulePtr Validate(const ShaderAst::Module& module, std::unordered_map<std::string, Option>* options);
NazaraSlot(ShaderModuleResolver, OnModuleUpdated, m_onShaderModuleUpdated); NazaraSlot(ShaderModuleResolver, OnModuleUpdated, m_onShaderModuleUpdated);

View File

@ -57,6 +57,7 @@ namespace Nz::ShaderAst
virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node); virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node);
virtual ExpressionPtr Clone(StructTypeExpression& node); virtual ExpressionPtr Clone(StructTypeExpression& node);
virtual ExpressionPtr Clone(SwizzleExpression& node); virtual ExpressionPtr Clone(SwizzleExpression& node);
virtual ExpressionPtr Clone(TypeExpression& node);
virtual ExpressionPtr Clone(VariableValueExpression& node); virtual ExpressionPtr Clone(VariableValueExpression& node);
virtual ExpressionPtr Clone(UnaryExpression& node); virtual ExpressionPtr Clone(UnaryExpression& node);

View File

@ -49,6 +49,7 @@ namespace Nz::ShaderAst
inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs); inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs);
inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs); inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs);
inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs); inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs);
inline bool Compare(const TypeExpression& lhs, const TypeExpression& rhs);
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs); inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs);
inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs); inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs);

View File

@ -407,6 +407,14 @@ namespace Nz::ShaderAst
return true; return true;
} }
bool Compare(const TypeExpression& lhs, const TypeExpression& rhs)
{
if (!Compare(lhs.typeId, rhs.typeId))
return false;
return true;
}
inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs) inline bool Compare(const VariableValueExpression& lhs, const VariableValueExpression& rhs)
{ {
if (!Compare(lhs.variableId, rhs.variableId)) if (!Compare(lhs.variableId, rhs.variableId))

View File

@ -36,7 +36,7 @@ namespace Nz::ShaderAst
struct Options struct Options
{ {
std::function<const ConstantValue&(std::size_t constantId)> constantQueryCallback; std::function<const ConstantValue*(std::size_t constantId)> constantQueryCallback;
}; };
protected: protected:

View File

@ -45,6 +45,7 @@ NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression)
NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression) NAZARA_SHADERAST_EXPRESSION(IntrinsicFunctionExpression)
NAZARA_SHADERAST_EXPRESSION(StructTypeExpression) NAZARA_SHADERAST_EXPRESSION(StructTypeExpression)
NAZARA_SHADERAST_EXPRESSION(SwizzleExpression) NAZARA_SHADERAST_EXPRESSION(SwizzleExpression)
NAZARA_SHADERAST_EXPRESSION(TypeExpression)
NAZARA_SHADERAST_EXPRESSION(VariableValueExpression) NAZARA_SHADERAST_EXPRESSION(VariableValueExpression)
NAZARA_SHADERAST_EXPRESSION(UnaryExpression) NAZARA_SHADERAST_EXPRESSION(UnaryExpression)
NAZARA_SHADERAST_STATEMENT(BranchStatement) NAZARA_SHADERAST_STATEMENT(BranchStatement)

View File

@ -37,6 +37,7 @@ namespace Nz::ShaderAst
void Visit(IntrinsicFunctionExpression& node) override; void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override; void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override; void Visit(SwizzleExpression& node) override;
void Visit(TypeExpression& node) override;
void Visit(VariableValueExpression& node) override; void Visit(VariableValueExpression& node) override;
void Visit(UnaryExpression& node) override; void Visit(UnaryExpression& node) override;

View File

@ -40,6 +40,7 @@ namespace Nz::ShaderAst
void Serialize(IntrinsicFunctionExpression& node); void Serialize(IntrinsicFunctionExpression& node);
void Serialize(StructTypeExpression& node); void Serialize(StructTypeExpression& node);
void Serialize(SwizzleExpression& node); void Serialize(SwizzleExpression& node);
void Serialize(TypeExpression& node);
void Serialize(VariableValueExpression& node); void Serialize(VariableValueExpression& node);
void Serialize(UnaryExpression& node); void Serialize(UnaryExpression& node);
void SerializeExpressionCommon(Expression& expr); void SerializeExpressionCommon(Expression& expr);

View File

@ -48,6 +48,7 @@ namespace Nz::ShaderAst
void Visit(IntrinsicFunctionExpression& node) override; void Visit(IntrinsicFunctionExpression& node) override;
void Visit(StructTypeExpression& node) override; void Visit(StructTypeExpression& node) override;
void Visit(SwizzleExpression& node) override; void Visit(SwizzleExpression& node) override;
void Visit(TypeExpression& node) override;
void Visit(VariableValueExpression& node) override; void Visit(VariableValueExpression& node) override;
void Visit(UnaryExpression& node) override; void Visit(UnaryExpression& node) override;

View File

@ -37,7 +37,7 @@ namespace Nz::ShaderAst
using ConstantValue = TypeListInstantiate<ConstantTypes, std::variant>; using ConstantValue = TypeListInstantiate<ConstantTypes, std::variant>;
NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant); NAZARA_SHADER_API ExpressionType GetConstantType(const ConstantValue& constant);
} }
#endif // NAZARA_SHADER_AST_CONSTANTVALUE_HPP #endif // NAZARA_SHADER_AST_CONSTANTVALUE_HPP

View File

@ -215,6 +215,14 @@ namespace Nz::ShaderAst
ExpressionPtr expression; ExpressionPtr expression;
}; };
struct NAZARA_SHADER_API TypeExpression : Expression
{
NodeType GetType() const override;
void Visit(AstExpressionVisitor& visitor) override;
std::size_t typeId;
};
struct NAZARA_SHADER_API VariableValueExpression : Expression struct NAZARA_SHADER_API VariableValueExpression : Expression
{ {
NodeType GetType() const override; NodeType GetType() const override;
@ -462,8 +470,8 @@ namespace Nz::ShaderAst
#include <Nazara/Shader/Ast/AstNodeList.hpp> #include <Nazara/Shader/Ast/AstNodeList.hpp>
inline const ExpressionType& GetExpressionType(Expression& expr); inline const ExpressionType* GetExpressionType(Expression& expr);
inline ExpressionType& GetExpressionTypeMut(Expression& expr); inline ExpressionType* GetExpressionTypeMut(Expression& expr);
inline bool IsExpression(NodeType nodeType); inline bool IsExpression(NodeType nodeType);
inline bool IsStatement(NodeType nodeType); inline bool IsStatement(NodeType nodeType);

View File

@ -7,16 +7,14 @@
namespace Nz::ShaderAst namespace Nz::ShaderAst
{ {
inline const ExpressionType& GetExpressionType(Expression& expr) inline const ExpressionType* GetExpressionType(Expression& expr)
{ {
assert(expr.cachedExpressionType); return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
return expr.cachedExpressionType.value();
} }
inline ExpressionType& GetExpressionTypeMut(Expression& expr) inline ExpressionType* GetExpressionTypeMut(Expression& expr)
{ {
assert(expr.cachedExpressionType); return (expr.cachedExpressionType) ? &expr.cachedExpressionType.value() : nullptr;
return expr.cachedExpressionType.value();
} }
inline const ExpressionType& ResolveAlias(const ExpressionType& exprType) inline const ExpressionType& ResolveAlias(const ExpressionType& exprType)

View File

@ -46,6 +46,7 @@ namespace Nz::ShaderAst
std::shared_ptr<ShaderModuleResolver> moduleResolver; std::shared_ptr<ShaderModuleResolver> moduleResolver;
std::unordered_set<std::string> reservedIdentifiers; std::unordered_set<std::string> reservedIdentifiers;
std::unordered_map<UInt32, ConstantValue> optionValues; std::unordered_map<UInt32, ConstantValue> optionValues;
bool allowPartialSanitization = false;
bool makeVariableNameUnique = false; bool makeVariableNameUnique = false;
bool reduceLoopsToWhile = false; bool reduceLoopsToWhile = false;
bool removeAliases = false; bool removeAliases = false;
@ -60,6 +61,7 @@ namespace Nz::ShaderAst
private: private:
enum class IdentifierCategory; enum class IdentifierCategory;
enum class ValidationResult;
struct AstError; struct AstError;
struct CurrentFunctionData; struct CurrentFunctionData;
struct Environment; struct Environment;
@ -110,10 +112,12 @@ namespace Nz::ShaderAst
template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const; template<typename F> const IdentifierData* FindIdentifier(const std::string_view& identifierName, F&& functor) const;
const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const; const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName) const;
template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const; template<typename F> const IdentifierData* FindIdentifier(const Environment& environment, const std::string_view& identifierName, F&& functor) const;
TypeParameter FindTypeParameter(const std::string_view& identifierName) const;
ExpressionPtr HandleIdentifier(const IdentifierData* identifierData); ExpressionPtr HandleIdentifier(const IdentifierData* identifierData);
const ExpressionType* GetExpressionType(Expression& expr) const;
const ExpressionType& GetExpressionTypeSecure(Expression& expr) const;
Expression& MandatoryExpr(const ExpressionPtr& node) const; Expression& MandatoryExpr(const ExpressionPtr& node) const;
Statement& MandatoryStatement(const StatementPtr& node) const; Statement& MandatoryStatement(const StatementPtr& node) const;
@ -122,8 +126,9 @@ namespace Nz::ShaderAst
ExpressionPtr CacheResult(ExpressionPtr expression); ExpressionPtr CacheResult(ExpressionPtr expression);
ConstantValue ComputeConstantValue(Expression& expr) const; std::optional<ConstantValue> ComputeConstantValue(Expression& expr) const;
template<typename T> const T& ComputeExprValue(ExpressionValue<T>& attribute) const; template<typename T> ValidationResult ComputeExprValue(ExpressionValue<T>& attribute) const;
template<typename T> ValidationResult ComputeExprValue(const ExpressionValue<T>& attribute, ExpressionValue<T>& targetAttribute);
template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const; template<typename T> std::unique_ptr<T> PropagateConstants(T& node) const;
void PreregisterIndices(const Module& module); void PreregisterIndices(const Module& module);
@ -131,49 +136,49 @@ namespace Nz::ShaderAst
void RegisterBuiltin(); void RegisterBuiltin();
std::size_t RegisterAlias(std::string name, IdentifierData aliasData, std::optional<std::size_t> index = {}); std::size_t RegisterAlias(std::string name, std::optional<IdentifierData> aliasData, std::optional<std::size_t> index = {});
std::size_t RegisterConstant(std::string name, ConstantValue value, std::optional<std::size_t> index = {}); std::size_t RegisterConstant(std::string name, std::optional<ConstantValue> value, std::optional<std::size_t> index = {});
std::size_t RegisterFunction(std::string name, FunctionData funcData, std::optional<std::size_t> index = {}); std::size_t RegisterFunction(std::string name, std::optional<FunctionData> funcData, std::optional<std::size_t> index = {});
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex); std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
std::size_t RegisterStruct(std::string name, StructDescription* description, std::optional<std::size_t> index = {}); std::size_t RegisterStruct(std::string name, std::optional<StructDescription*> description, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, ExpressionType expressionType, std::optional<std::size_t> index = {}); std::size_t RegisterType(std::string name, std::optional<ExpressionType> expressionType, std::optional<std::size_t> index = {});
std::size_t RegisterType(std::string name, PartialType partialType, std::optional<std::size_t> index = {}); std::size_t RegisterType(std::string name, std::optional<PartialType> partialType, std::optional<std::size_t> index = {});
std::size_t RegisterVariable(std::string name, ExpressionType type, std::optional<std::size_t> index = {}); void RegisterUnresolved(std::string name);
std::size_t RegisterVariable(std::string name, std::optional<ExpressionType> type, std::optional<std::size_t> index = {});
const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const; const IdentifierData* ResolveAliasIdentifier(const IdentifierData* identifier) const;
void ResolveFunctions(); void ResolveFunctions();
const ExpressionPtr& ResolveCondExpression(ConditionalExpression& node);
std::size_t ResolveStruct(const AliasType& aliasType); std::size_t ResolveStruct(const AliasType& aliasType);
std::size_t ResolveStruct(const ExpressionType& exprType); std::size_t ResolveStruct(const ExpressionType& exprType);
std::size_t ResolveStruct(const IdentifierType& identifierType); std::size_t ResolveStruct(const IdentifierType& identifierType);
std::size_t ResolveStruct(const StructType& structType); std::size_t ResolveStruct(const StructType& structType);
std::size_t ResolveStruct(const UniformType& uniformType); std::size_t ResolveStruct(const UniformType& uniformType);
ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false); ExpressionType ResolveType(const ExpressionType& exprType, bool resolveAlias = false);
ExpressionType ResolveType(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false); std::optional<ExpressionType> ResolveTypeExpr(const ExpressionValue<ExpressionType>& exprTypeValue, bool resolveAlias = false);
void SanitizeIdentifier(std::string& identifier); void SanitizeIdentifier(std::string& identifier);
MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error); MultiStatementPtr SanitizeInternal(MultiStatement& rootNode, std::string* error);
void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const; ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) const;
void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const; void TypeMustMatch(const ExpressionType& left, const ExpressionType& right) const;
StatementPtr Unscope(StatementPtr node); StatementPtr Unscope(StatementPtr node);
void Validate(DeclareAliasStatement& node); ValidationResult Validate(DeclareAliasStatement& node);
void Validate(WhileStatement& node); ValidationResult Validate(WhileStatement& node);
void Validate(AccessIndexExpression& node); ValidationResult Validate(AccessIndexExpression& node);
void Validate(AssignExpression& node); ValidationResult Validate(AssignExpression& node);
void Validate(BinaryExpression& node); ValidationResult Validate(BinaryExpression& node);
void Validate(CallFunctionExpression& node); ValidationResult Validate(CallFunctionExpression& node);
void Validate(CastExpression& node); ValidationResult Validate(CastExpression& node);
void Validate(DeclareVariableStatement& node); ValidationResult Validate(DeclareVariableStatement& node);
void Validate(IntrinsicExpression& node); ValidationResult Validate(IntrinsicExpression& node);
void Validate(SwizzleExpression& node); ValidationResult Validate(SwizzleExpression& node);
void Validate(UnaryExpression& node); ValidationResult Validate(UnaryExpression& node);
void Validate(VariableValueExpression& node); ValidationResult Validate(VariableValueExpression& node);
ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionPtr& leftExpr, const ExpressionPtr& rightExpr); ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType);
enum class IdentifierCategory enum class IdentifierCategory
{ {
@ -184,9 +189,16 @@ namespace Nz::ShaderAst
Module, Module,
Struct, Struct,
Type, Type,
Unresolved,
Variable Variable
}; };
enum class ValidationResult
{
Validated,
Unresolved
};
struct FunctionData struct FunctionData
{ {
Bitset<> calledByFunctions; Bitset<> calledByFunctions;

View File

@ -173,7 +173,7 @@ namespace Nz::ShaderBuilder
{ {
auto constantNode = std::make_unique<ShaderAst::ConstantValueExpression>(); auto constantNode = std::make_unique<ShaderAst::ConstantValueExpression>();
constantNode->value = std::move(value); constantNode->value = std::move(value);
constantNode->cachedExpressionType = ShaderAst::GetExpressionType(constantNode->value); constantNode->cachedExpressionType = ShaderAst::GetConstantType(constantNode->value);
return constantNode; return constantNode;
} }

View File

@ -249,7 +249,7 @@ struct VertIn
} }
[entry(vert), cond(Billboard)] [entry(vert), cond(Billboard)]
fn billboardMain(input: VertIn) -> VertOut fn billboardMain(input: VertIn) -> VertToFrag
{ {
let size = input.billboardSizeRot.xy; let size = input.billboardSizeRot.xy;
let sinCos = input.billboardSizeRot.zw; let sinCos = input.billboardSizeRot.zw;

View File

@ -25,7 +25,7 @@ namespace Nz
m_shaderModule = moduleResolver.Resolve(moduleName); m_shaderModule = moduleResolver.Resolve(moduleName);
NazaraAssert(m_shaderModule, "invalid shader module"); NazaraAssert(m_shaderModule, "invalid shader module");
Validate(*m_shaderModule); m_shaderModule = Validate(*m_shaderModule, &m_optionIndexByName);
m_onShaderModuleUpdated.Connect(moduleResolver.OnModuleUpdated, [this, name = std::move(moduleName)](ShaderModuleResolver* resolver, const std::string& updatedModuleName) m_onShaderModuleUpdated.Connect(moduleResolver.OnModuleUpdated, [this, name = std::move(moduleName)](ShaderModuleResolver* resolver, const std::string& updatedModuleName)
{ {
@ -41,8 +41,7 @@ namespace Nz
try try
{ {
// FIXME: Validate is destructive, in case of failure it can invalidate the shader m_shaderModule = Validate(*newShaderModule, &m_optionIndexByName);
Validate(*newShaderModule);
} }
catch (const std::exception& e) catch (const std::exception& e)
{ {
@ -50,8 +49,6 @@ namespace Nz
return; return;
} }
m_shaderModule = std::move(newShaderModule);
// Clear cache // Clear cache
m_combinations.clear(); m_combinations.clear();
@ -65,7 +62,7 @@ namespace Nz
{ {
NazaraAssert(m_shaderModule, "invalid shader module"); NazaraAssert(m_shaderModule, "invalid shader module");
Validate(*m_shaderModule); Validate(*m_shaderModule, &m_optionIndexByName);
} }
const std::shared_ptr<ShaderModule>& UberShader::Get(const Config& config) const std::shared_ptr<ShaderModule>& UberShader::Get(const Config& config)
@ -85,13 +82,17 @@ namespace Nz
return it->second; return it->second;
} }
void UberShader::Validate(ShaderAst::Module& module) ShaderAst::ModulePtr UberShader::Validate(const ShaderAst::Module& module, std::unordered_map<std::string, Option>* options)
{ {
NazaraAssert(m_shaderStages != 0, "there must be at least one shader stage"); NazaraAssert(m_shaderStages != 0, "there must be at least one shader stage");
assert(options);
//TODO: Try to partially sanitize shader? // Try to partially sanitize shader
std::size_t optionCount = 0; ShaderAst::SanitizeVisitor::Options sanitizeOptions;
sanitizeOptions.allowPartialSanitization = true;
ShaderAst::ModulePtr sanitizedModule = ShaderAst::Sanitize(module, sanitizeOptions);
ShaderStageTypeFlags supportedStageType; ShaderStageTypeFlags supportedStageType;
@ -101,21 +102,24 @@ namespace Nz
supportedStageType |= stageType; supportedStageType |= stageType;
}; };
std::unordered_map<std::string, Option> optionByName;
callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option) callbacks.onOptionDeclaration = [&](const ShaderAst::DeclareOptionStatement& option)
{ {
//TODO: Check optionType //TODO: Check optionType
m_optionIndexByName[option.optName] = Option{ optionByName[option.optName] = Option{
CRC32(option.optName) CRC32(option.optName)
}; };
optionCount++;
}; };
ShaderAst::AstReflect reflect; ShaderAst::AstReflect reflect;
reflect.Reflect(*module.rootNode, callbacks); reflect.Reflect(*sanitizedModule->rootNode, callbacks);
if ((m_shaderStages & supportedStageType) != m_shaderStages) if ((m_shaderStages & supportedStageType) != m_shaderStages)
throw std::runtime_error("shader doesn't support all required shader stages"); throw std::runtime_error("shader doesn't support all required shader stages");
*options = std::move(optionByName);
return sanitizedModule;
} }
} }

View File

@ -481,6 +481,16 @@ namespace Nz::ShaderAst
return clone; return clone;
} }
ExpressionPtr AstCloner::Clone(TypeExpression& node)
{
auto clone = std::make_unique<TypeExpression>();
clone->typeId = node.typeId;
clone->cachedExpressionType = node.cachedExpressionType;
return clone;
}
ExpressionPtr AstCloner::Clone(VariableValueExpression& node) ExpressionPtr AstCloner::Clone(VariableValueExpression& node)
{ {
auto clone = std::make_unique<VariableValueExpression>(); auto clone = std::make_unique<VariableValueExpression>();

View File

@ -862,7 +862,7 @@ namespace Nz::ShaderAst
const auto& constantExpr = static_cast<ConstantValueExpression&>(*expressions[i]); const auto& constantExpr = static_cast<ConstantValueExpression&>(*expressions[i]);
if (!constantValues.empty() && GetExpressionType(constantValues.front()) != GetExpressionType(constantExpr.value)) if (!constantValues.empty() && GetConstantType(constantValues.front()) != GetConstantType(constantExpr.value))
{ {
// Unhandled case, all cast parameters are expected to be of the same type // Unhandled case, all cast parameters are expected to be of the same type
constantValues.clear(); constantValues.clear();
@ -940,16 +940,24 @@ namespace Nz::ShaderAst
std::vector<BranchStatement::ConditionalStatement> statements; std::vector<BranchStatement::ConditionalStatement> statements;
StatementPtr elseStatement; StatementPtr elseStatement;
bool continuePropagation = true;
for (auto& condStatement : node.condStatements) for (auto& condStatement : node.condStatements)
{ {
auto cond = CloneExpression(condStatement.condition); auto cond = CloneExpression(condStatement.condition);
if (cond->GetType() == NodeType::ConstantValueExpression) if (continuePropagation && cond->GetType() == NodeType::ConstantValueExpression)
{ {
auto& constant = static_cast<ConstantValueExpression&>(*cond); auto& constant = static_cast<ConstantValueExpression&>(*cond);
const ExpressionType& constantType = GetExpressionType(constant); const ExpressionType* constantType = GetExpressionType(constant);
if (!IsPrimitiveType(constantType) || std::get<PrimitiveType>(constantType) != PrimitiveType::Boolean) if (!constantType)
{
// unresolved type, can't continue propagating this branch
continuePropagation = false;
continue;
}
if (!IsPrimitiveType(*constantType) || std::get<PrimitiveType>(*constantType) != PrimitiveType::Boolean)
continue; continue;
bool cValue = std::get<bool>(constant.value); bool cValue = std::get<bool>(constant.value);
@ -1017,8 +1025,12 @@ namespace Nz::ShaderAst
if (!m_options.constantQueryCallback) if (!m_options.constantQueryCallback)
return AstCloner::Clone(node); return AstCloner::Clone(node);
auto constant = ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId)); const ConstantValue* constantValue = m_options.constantQueryCallback(node.constantId);
constant->cachedExpressionType = GetExpressionType(constant->value); if (!constantValue)
return AstCloner::Clone(node);
auto constant = ShaderBuilder::Constant(*constantValue);
constant->cachedExpressionType = GetConstantType(constant->value);
return constant; return constant;
} }
@ -1155,7 +1167,7 @@ namespace Nz::ShaderAst
}, lhs.value); }, lhs.value);
if (optimized) if (optimized)
optimized->cachedExpressionType = GetExpressionType(optimized->value); optimized->cachedExpressionType = GetConstantType(optimized->value);
return optimized; return optimized;
} }
@ -1221,7 +1233,7 @@ namespace Nz::ShaderAst
}, operand.value); }, operand.value);
if (optimized) if (optimized)
optimized->cachedExpressionType = GetExpressionType(optimized->value); optimized->cachedExpressionType = GetConstantType(optimized->value);
return optimized; return optimized;
} }

View File

@ -109,6 +109,11 @@ namespace Nz::ShaderAst
node.expression->Visit(*this); node.expression->Visit(*this);
} }
void AstRecursiveVisitor::Visit(TypeExpression& node)
{
/* Nothing to do */
}
void AstRecursiveVisitor::Visit(VariableValueExpression& /*node*/) void AstRecursiveVisitor::Visit(VariableValueExpression& /*node*/)
{ {
/* Nothing to do */ /* Nothing to do */

View File

@ -174,6 +174,11 @@ namespace Nz::ShaderAst
SizeT(node.structTypeId); SizeT(node.structTypeId);
} }
void AstSerializerBase::Serialize(TypeExpression& node)
{
SizeT(node.typeId);
}
void AstSerializerBase::Serialize(FunctionExpression& node) void AstSerializerBase::Serialize(FunctionExpression& node)
{ {
SizeT(node.funcId); SizeT(node.funcId);

View File

@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp // For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/Ast/AstUtils.hpp> #include <Nazara/Shader/Ast/AstUtils.hpp>
#include <cassert>
#include <Nazara/Shader/Debug.hpp> #include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderAst namespace Nz::ShaderAst
@ -104,7 +105,10 @@ namespace Nz::ShaderAst
void ShaderAstValueCategory::Visit(SwizzleExpression& node) void ShaderAstValueCategory::Visit(SwizzleExpression& node)
{ {
if (IsPrimitiveType(GetExpressionType(node)) && node.componentCount > 1) const ExpressionType* exprType = GetExpressionType(node);
assert(exprType);
if (IsPrimitiveType(*exprType) && node.componentCount > 1)
// Swizzling more than a component on a primitive produces a rvalue (a.xxxx cannot be assigned) // Swizzling more than a component on a primitive produces a rvalue (a.xxxx cannot be assigned)
m_expressionCategory = ExpressionCategory::RValue; m_expressionCategory = ExpressionCategory::RValue;
else else
@ -133,6 +137,11 @@ namespace Nz::ShaderAst
} }
} }
void ShaderAstValueCategory::Visit(TypeExpression& /*node*/)
{
m_expressionCategory = ExpressionCategory::LValue;
}
void ShaderAstValueCategory::Visit(VariableValueExpression& /*node*/) void ShaderAstValueCategory::Visit(VariableValueExpression& /*node*/)
{ {
m_expressionCategory = ExpressionCategory::LValue; m_expressionCategory = ExpressionCategory::LValue;

View File

@ -8,7 +8,7 @@
namespace Nz::ShaderAst namespace Nz::ShaderAst
{ {
ExpressionType GetExpressionType(const ConstantValue& constant) ExpressionType GetConstantType(const ConstantValue& constant)
{ {
return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType return std::visit([&](auto&& arg) -> ShaderAst::ExpressionType
{ {

View File

@ -15,10 +15,11 @@ namespace Nz::ShaderAst
void DependencyCheckerVisitor::Visit(CallFunctionExpression& node) void DependencyCheckerVisitor::Visit(CallFunctionExpression& node)
{ {
const auto& targetFuncType = GetExpressionType(*node.targetFunction); const ExpressionType* targetFuncType = GetExpressionType(*node.targetFunction);
assert(std::holds_alternative<FunctionType>(targetFuncType)); assert(targetFuncType);
assert(std::holds_alternative<FunctionType>(*targetFuncType));
const auto& funcType = std::get<FunctionType>(targetFuncType); const auto& funcType = std::get<FunctionType>(*targetFuncType);
assert(m_currentFunctionIndex); assert(m_currentFunctionIndex);
if (m_currentVariableDeclIndex) if (m_currentVariableDeclIndex)

File diff suppressed because it is too large Load Diff

View File

@ -36,7 +36,7 @@ namespace Nz
AstRecursiveVisitor::Visit(node); AstRecursiveVisitor::Visit(node);
assert(currentFunction); assert(currentFunction);
currentFunction->calledFunctions.UnboundedSet(std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex); currentFunction->calledFunctions.UnboundedSet(std::get<ShaderAst::FunctionType>(*GetExpressionType(*node.targetFunction)).funcIndex);
} }
void Visit(ShaderAst::ConditionalExpression& /*node*/) override void Visit(ShaderAst::ConditionalExpression& /*node*/) override
@ -307,7 +307,7 @@ namespace Nz
Append(type.GetResultingValue()); Append(type.GetResultingValue());
} }
void GlslWriter::Append(const ShaderAst::FunctionType& functionType) void GlslWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
{ {
throw std::runtime_error("unexpected FunctionType"); throw std::runtime_error("unexpected FunctionType");
} }
@ -829,8 +829,9 @@ namespace Nz
{ {
Visit(node.expr, true); Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr);
assert(IsStructType(exprType)); assert(exprType);
assert(IsStructType(*exprType));
for (const std::string& identifier : node.identifiers) for (const std::string& identifier : node.identifiers)
Append(".", identifier); Append(".", identifier);
@ -840,8 +841,9 @@ namespace Nz
{ {
Visit(node.expr, true); Visit(node.expr, true);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expr); const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expr);
assert(!IsStructType(exprType)); assert(exprType);
assert(!IsStructType(*exprType));
// Array access // Array access
assert(node.indices.size() == 1); assert(node.indices.size() == 1);
@ -1326,9 +1328,10 @@ namespace Nz
{ {
assert(node.returnExpr); assert(node.returnExpr);
const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr); const ShaderAst::ExpressionType* returnType = GetExpressionType(*node.returnExpr);
assert(IsStructType(returnType)); assert(returnType);
std::size_t structIndex = std::get<ShaderAst::StructType>(returnType).structIndex; assert(IsStructType(*returnType));
std::size_t structIndex = std::get<ShaderAst::StructType>(*returnType).structIndex;
const auto& structData = Retrieve(m_currentState->structs, structIndex); const auto& structData = Retrieve(m_currentState->structs, structIndex);
std::string outputStructVarName; std::string outputStructVarName;

View File

@ -182,7 +182,7 @@ namespace Nz
type.GetExpression()->Visit(*this); type.GetExpression()->Visit(*this);
} }
void LangWriter::Append(const ShaderAst::FunctionType& functionType) void LangWriter::Append(const ShaderAst::FunctionType& /*functionType*/)
{ {
throw std::runtime_error("unexpected function type"); throw std::runtime_error("unexpected function type");
} }

View File

@ -263,7 +263,7 @@ namespace Nz::ShaderLang
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" }; throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr); auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String }) if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" }; throw AttributeError{ "attribute " + std::string("nzsl_version") + " expect a single string parameter" };
const std::string& versionStr = std::get<std::string>(constantValue.value); const std::string& versionStr = std::get<std::string>(constantValue.value);
@ -302,7 +302,7 @@ namespace Nz::ShaderLang
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" }; throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr); auto& constantValue = SafeCast<ShaderAst::ConstantValueExpression&>(*expr);
if (ShaderAst::GetExpressionType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String }) if (ShaderAst::GetConstantType(constantValue.value) != ShaderAst::ExpressionType{ ShaderAst::PrimitiveType::String })
throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" }; throw AttributeError{ "attribute " + std::string("uuid") + " expect a single string parameter" };
const std::string& uuidStr = std::get<std::string>(constantValue.value); const std::string& uuidStr = std::get<std::string>(constantValue.value);

View File

@ -67,9 +67,9 @@ namespace Nz
throw std::runtime_error("unexpected type"); throw std::runtime_error("unexpected type");
}; };
const ShaderAst::ExpressionType& resultType = GetExpressionType(node); const ShaderAst::ExpressionType& resultType = *GetExpressionType(node);
const ShaderAst::ExpressionType& leftType = GetExpressionType(*node.left); const ShaderAst::ExpressionType& leftType = *GetExpressionType(*node.left);
const ShaderAst::ExpressionType& rightType = GetExpressionType(*node.right); const ShaderAst::ExpressionType& rightType = *GetExpressionType(*node.right);
ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType); ShaderAst::PrimitiveType leftTypeBase = RetrieveBaseType(leftType);
//ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType); //ShaderAst::PrimitiveType rightTypeBase = RetrieveBaseType(rightType);
@ -405,7 +405,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node) void SpirvAstVisitor::Visit(ShaderAst::CallFunctionExpression& node)
{ {
std::size_t functionIndex = std::get<ShaderAst::FunctionType>(GetExpressionType(*node.targetFunction)).funcIndex; std::size_t functionIndex = std::get<ShaderAst::FunctionType>(*GetExpressionType(*node.targetFunction)).funcIndex;
UInt32 funcId = 0; UInt32 funcId = 0;
for (const auto& [funcIndex, func] : m_funcData) for (const auto& [funcIndex, func] : m_funcData)
@ -434,7 +434,7 @@ namespace Nz
UInt32 resultId = AllocateResultId(); UInt32 resultId = AllocateResultId();
m_currentBlock->AppendVariadic(SpirvOp::OpFunctionCall, [&](auto&& appender) m_currentBlock->AppendVariadic(SpirvOp::OpFunctionCall, [&](auto&& appender)
{ {
appender(m_writer.GetTypeId(ShaderAst::GetExpressionType(node))); appender(m_writer.GetTypeId(*ShaderAst::GetExpressionType(node)));
appender(resultId); appender(resultId);
appender(funcId); appender(funcId);
@ -718,9 +718,11 @@ namespace Nz
{ {
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(parameterType)); assert(parameterType);
UInt32 typeId = m_writer.GetTypeId(parameterType); assert(IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
UInt32 firstParam = EvaluateExpression(node.parameters[0]); UInt32 firstParam = EvaluateExpression(node.parameters[0]);
UInt32 secondParam = EvaluateExpression(node.parameters[1]); UInt32 secondParam = EvaluateExpression(node.parameters[1]);
@ -733,10 +735,11 @@ namespace Nz
case ShaderAst::IntrinsicType::DotProduct: case ShaderAst::IntrinsicType::DotProduct:
{ {
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(vecExprType)); assert(vecExprType);
assert(IsVectorType(*vecExprType));
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType); const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
UInt32 typeId = m_writer.GetTypeId(vecType.type); UInt32 typeId = m_writer.GetTypeId(vecType.type);
@ -754,9 +757,10 @@ namespace Nz
{ {
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType)); assert(parameterType);
UInt32 typeId = m_writer.GetTypeId(parameterType); assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
UInt32 param = EvaluateExpression(node.parameters[0]); UInt32 param = EvaluateExpression(node.parameters[0]);
UInt32 resultId = m_writer.AllocateResultId(); UInt32 resultId = m_writer.AllocateResultId();
@ -770,10 +774,11 @@ namespace Nz
{ {
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(vecExprType)); assert(vecExprType);
assert(IsVectorType(*vecExprType));
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType); const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
UInt32 typeId = m_writer.GetTypeId(vecType.type); UInt32 typeId = m_writer.GetTypeId(vecType.type);
UInt32 vec = EvaluateExpression(node.parameters[0]); UInt32 vec = EvaluateExpression(node.parameters[0]);
@ -790,15 +795,16 @@ namespace Nz
{ {
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType)); assert(parameterType);
UInt32 typeId = m_writer.GetTypeId(parameterType); assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
ShaderAst::PrimitiveType basicType; ShaderAst::PrimitiveType basicType;
if (IsPrimitiveType(parameterType)) if (IsPrimitiveType(*parameterType))
basicType = std::get<ShaderAst::PrimitiveType>(parameterType); basicType = std::get<ShaderAst::PrimitiveType>(*parameterType);
else if (IsVectorType(parameterType)) else if (IsVectorType(*parameterType))
basicType = std::get<ShaderAst::VectorType>(parameterType).type; basicType = std::get<ShaderAst::VectorType>(*parameterType).type;
else else
throw std::runtime_error("unexpected expression type"); throw std::runtime_error("unexpected expression type");
@ -837,10 +843,11 @@ namespace Nz
{ {
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& vecExprType = GetExpressionType(*node.parameters[0]); const ShaderAst::ExpressionType* vecExprType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(vecExprType)); assert(vecExprType);
assert(IsVectorType(*vecExprType));
const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(vecExprType); const ShaderAst::VectorType& vecType = std::get<ShaderAst::VectorType>(*vecExprType);
UInt32 typeId = m_writer.GetTypeId(vecType); UInt32 typeId = m_writer.GetTypeId(vecType);
UInt32 vec = EvaluateExpression(node.parameters[0]); UInt32 vec = EvaluateExpression(node.parameters[0]);
@ -856,9 +863,10 @@ namespace Nz
{ {
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(IsPrimitiveType(parameterType) || IsVectorType(parameterType)); assert(parameterType);
UInt32 typeId = m_writer.GetTypeId(parameterType); assert(IsPrimitiveType(*parameterType) || IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
UInt32 firstParam = EvaluateExpression(node.parameters[0]); UInt32 firstParam = EvaluateExpression(node.parameters[0]);
UInt32 secondParam = EvaluateExpression(node.parameters[1]); UInt32 secondParam = EvaluateExpression(node.parameters[1]);
@ -873,9 +881,10 @@ namespace Nz
{ {
UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450"); UInt32 glslInstructionSet = m_writer.GetExtendedInstructionSet("GLSL.std.450");
const ShaderAst::ExpressionType& parameterType = GetExpressionType(*node.parameters[0]); const ShaderAst::ExpressionType* parameterType = GetExpressionType(*node.parameters[0]);
assert(IsVectorType(parameterType)); assert(parameterType);
UInt32 typeId = m_writer.GetTypeId(parameterType); assert(IsVectorType(*parameterType));
UInt32 typeId = m_writer.GetTypeId(*parameterType);
UInt32 firstParam = EvaluateExpression(node.parameters[0]); UInt32 firstParam = EvaluateExpression(node.parameters[0]);
UInt32 secondParam = EvaluateExpression(node.parameters[1]); UInt32 secondParam = EvaluateExpression(node.parameters[1]);
@ -951,20 +960,22 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node) void SpirvAstVisitor::Visit(ShaderAst::SwizzleExpression& node)
{ {
const ShaderAst::ExpressionType& swizzledExpressionType = GetExpressionType(*node.expression); const ShaderAst::ExpressionType* swizzledExpressionType = GetExpressionType(*node.expression);
assert(swizzledExpressionType);
UInt32 exprResultId = EvaluateExpression(node.expression); UInt32 exprResultId = EvaluateExpression(node.expression);
const ShaderAst::ExpressionType& targetExprType = GetExpressionType(node); const ShaderAst::ExpressionType* targetExprType = GetExpressionType(node);
assert(targetExprType);
if (node.componentCount > 1) if (node.componentCount > 1)
{ {
assert(IsVectorType(targetExprType)); assert(IsVectorType(*targetExprType));
const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(targetExprType); const ShaderAst::VectorType& targetType = std::get<ShaderAst::VectorType>(*targetExprType);
UInt32 resultId = m_writer.AllocateResultId(); UInt32 resultId = m_writer.AllocateResultId();
if (IsVectorType(swizzledExpressionType)) if (IsVectorType(*swizzledExpressionType))
{ {
// Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands // Swizzling a vector is implemented via OpVectorShuffle using the same vector twice as operands
m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender) m_currentBlock->AppendVariadic(SpirvOp::OpVectorShuffle, [&](const auto& appender)
@ -980,7 +991,7 @@ namespace Nz
} }
else else
{ {
assert(IsPrimitiveType(swizzledExpressionType)); assert(IsPrimitiveType(*swizzledExpressionType));
// Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct // Swizzling a primitive to a vector (a.xxx) can be implemented using OpCompositeConstruct
m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender) m_currentBlock->AppendVariadic(SpirvOp::OpCompositeConstruct, [&](const auto& appender)
@ -995,10 +1006,10 @@ namespace Nz
PushResultId(resultId); PushResultId(resultId);
} }
else if (IsVectorType(swizzledExpressionType)) else if (IsVectorType(*swizzledExpressionType))
{ {
assert(IsPrimitiveType(targetExprType)); assert(IsPrimitiveType(*targetExprType));
ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(targetExprType); ShaderAst::PrimitiveType targetType = std::get<ShaderAst::PrimitiveType>(*targetExprType);
// Extract a single component from the vector // Extract a single component from the vector
assert(node.componentCount == 1); assert(node.componentCount == 1);
@ -1011,8 +1022,8 @@ namespace Nz
else else
{ {
// Swizzling a primitive to itself (a.x for example), don't do anything // Swizzling a primitive to itself (a.x for example), don't do anything
assert(IsPrimitiveType(swizzledExpressionType)); assert(IsPrimitiveType(*swizzledExpressionType));
assert(IsPrimitiveType(targetExprType)); assert(IsPrimitiveType(*targetExprType));
assert(node.componentCount == 1); assert(node.componentCount == 1);
assert(node.components[0] == 0); assert(node.components[0] == 0);
@ -1022,8 +1033,11 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node) void SpirvAstVisitor::Visit(ShaderAst::UnaryExpression& node)
{ {
const ShaderAst::ExpressionType& resultType = GetExpressionType(node); const ShaderAst::ExpressionType* resultType = GetExpressionType(node);
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node.expression); assert(resultType);
const ShaderAst::ExpressionType* exprType = GetExpressionType(*node.expression);
assert(exprType);
UInt32 operand = EvaluateExpression(node.expression); UInt32 operand = EvaluateExpression(node.expression);
@ -1033,11 +1047,11 @@ namespace Nz
{ {
case ShaderAst::UnaryType::LogicalNot: case ShaderAst::UnaryType::LogicalNot:
{ {
assert(IsPrimitiveType(exprType)); assert(IsPrimitiveType(*exprType));
assert(std::get<ShaderAst::PrimitiveType>(resultType) == ShaderAst::PrimitiveType::Boolean); assert(std::get<ShaderAst::PrimitiveType>(*resultType) == ShaderAst::PrimitiveType::Boolean);
UInt32 resultId = m_writer.AllocateResultId(); UInt32 resultId = m_writer.AllocateResultId();
m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(resultType), resultId, operand); m_currentBlock->Append(SpirvOp::OpLogicalNot, m_writer.GetTypeId(*resultType), resultId, operand);
return resultId; return resultId;
} }
@ -1045,10 +1059,10 @@ namespace Nz
case ShaderAst::UnaryType::Minus: case ShaderAst::UnaryType::Minus:
{ {
ShaderAst::PrimitiveType basicType; ShaderAst::PrimitiveType basicType;
if (IsPrimitiveType(exprType)) if (IsPrimitiveType(*exprType))
basicType = std::get<ShaderAst::PrimitiveType>(exprType); basicType = std::get<ShaderAst::PrimitiveType>(*exprType);
else if (IsVectorType(exprType)) else if (IsVectorType(*exprType))
basicType = std::get<ShaderAst::VectorType>(exprType).type; basicType = std::get<ShaderAst::VectorType>(*exprType).type;
else else
throw std::runtime_error("unexpected expression type"); throw std::runtime_error("unexpected expression type");
@ -1057,12 +1071,12 @@ namespace Nz
switch (basicType) switch (basicType)
{ {
case ShaderAst::PrimitiveType::Float32: case ShaderAst::PrimitiveType::Float32:
m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(resultType), resultId, operand); m_currentBlock->Append(SpirvOp::OpFNegate, m_writer.GetTypeId(*resultType), resultId, operand);
return resultId; return resultId;
case ShaderAst::PrimitiveType::Int32: case ShaderAst::PrimitiveType::Int32:
case ShaderAst::PrimitiveType::UInt32: case ShaderAst::PrimitiveType::UInt32:
m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(resultType), resultId, operand); m_currentBlock->Append(SpirvOp::OpSNegate, m_writer.GetTypeId(*resultType), resultId, operand);
return resultId; return resultId;
default: default:

View File

@ -76,9 +76,10 @@ namespace Nz
{ {
node.expr->Visit(*this); node.expr->Visit(*this);
const ShaderAst::ExpressionType& exprType = GetExpressionType(node); const ShaderAst::ExpressionType* exprType = GetExpressionType(node);
assert(exprType);
UInt32 typeId = m_writer.GetTypeId(exprType); UInt32 typeId = m_writer.GetTypeId(*exprType);
assert(node.indices.size() == 1); assert(node.indices.size() == 1);
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front()); UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
@ -88,7 +89,7 @@ namespace Nz
[&](const Pointer& pointer) [&](const Pointer& pointer)
{ {
PointerChainAccess pointerChainAccess; PointerChainAccess pointerChainAccess;
pointerChainAccess.exprType = &exprType; pointerChainAccess.exprType = exprType;
pointerChainAccess.indices = { indexId }; pointerChainAccess.indices = { indexId };
pointerChainAccess.pointedTypeId = pointer.pointedTypeId; pointerChainAccess.pointedTypeId = pointer.pointedTypeId;
pointerChainAccess.pointerId = pointer.pointerId; pointerChainAccess.pointerId = pointer.pointerId;
@ -98,7 +99,7 @@ namespace Nz
}, },
[&](PointerChainAccess& pointerChainAccess) [&](PointerChainAccess& pointerChainAccess)
{ {
pointerChainAccess.exprType = &exprType; pointerChainAccess.exprType = exprType;
pointerChainAccess.indices.push_back(indexId); pointerChainAccess.indices.push_back(indexId);
}, },
[&](const Value& value) [&](const Value& value)

View File

@ -8,6 +8,7 @@
#include <Nazara/Shader/SpirvAstVisitor.hpp> #include <Nazara/Shader/SpirvAstVisitor.hpp>
#include <Nazara/Shader/SpirvBlock.hpp> #include <Nazara/Shader/SpirvBlock.hpp>
#include <Nazara/Shader/SpirvWriter.hpp> #include <Nazara/Shader/SpirvWriter.hpp>
#include <cassert>
#include <numeric> #include <numeric>
#include <Nazara/Shader/Debug.hpp> #include <Nazara/Shader/Debug.hpp>
@ -61,11 +62,12 @@ namespace Nz
} }
else else
{ {
const ShaderAst::ExpressionType& exprType = GetExpressionType(*node);
assert(swizzledPointer.componentCount == 1); assert(swizzledPointer.componentCount == 1);
UInt32 pointerType = m_writer.RegisterPointerType(exprType, swizzledPointer.storage); //< FIXME const ShaderAst::ExpressionType* exprType = GetExpressionType(*node);
assert(exprType);
UInt32 pointerType = m_writer.RegisterPointerType(*exprType, swizzledPointer.storage); //< FIXME
// Access chain // Access chain
UInt32 indexId = m_writer.GetConstantId(SafeCast<Int32>(swizzledPointer.swizzleIndices[0])); UInt32 indexId = m_writer.GetConstantId(SafeCast<Int32>(swizzledPointer.swizzleIndices[0]));
@ -86,14 +88,15 @@ namespace Nz
{ {
node.expr->Visit(*this); node.expr->Visit(*this);
const ShaderAst::ExpressionType& exprType = GetExpressionType(node); const ShaderAst::ExpressionType* exprType = GetExpressionType(node);
assert(exprType);
std::visit(Overloaded std::visit(Overloaded
{ {
[&](const Pointer& pointer) [&](const Pointer& pointer)
{ {
UInt32 resultId = m_visitor.AllocateResultId(); UInt32 resultId = m_visitor.AllocateResultId();
UInt32 pointerType = m_writer.RegisterPointerType(exprType, pointer.storage); //< FIXME UInt32 pointerType = m_writer.RegisterPointerType(*exprType, pointer.storage); //< FIXME
assert(node.indices.size() == 1); assert(node.indices.size() == 1);
UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front()); UInt32 indexId = m_visitor.EvaluateExpression(node.indices.front());
@ -117,13 +120,14 @@ namespace Nz
{ {
[&](const Pointer& pointer) [&](const Pointer& pointer)
{ {
const auto& expressionType = GetExpressionType(*node.expression); const ShaderAst::ExpressionType* expressionType = GetExpressionType(*node.expression);
assert(IsVectorType(expressionType)); assert(expressionType);
assert(IsVectorType(*expressionType));
SwizzledPointer swizzledPointer; SwizzledPointer swizzledPointer;
swizzledPointer.pointerId = pointer.pointerId; swizzledPointer.pointerId = pointer.pointerId;
swizzledPointer.storage = pointer.storage; swizzledPointer.storage = pointer.storage;
swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(expressionType); swizzledPointer.swizzledType = std::get<ShaderAst::VectorType>(*expressionType);
swizzledPointer.componentCount = node.componentCount; swizzledPointer.componentCount = node.componentCount;
swizzledPointer.swizzleIndices = node.components; swizzledPointer.swizzleIndices = node.components;

View File

@ -98,7 +98,7 @@ namespace Nz
for (const auto& parameter : node.parameters) for (const auto& parameter : node.parameters)
{ {
auto& var = func.variables.emplace_back(); auto& var = func.variables.emplace_back();
var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(GetExpressionType(*parameter), SpirvStorageClass::Function)); var.typeId = m_constantCache.Register(*m_constantCache.BuildPointerType(*GetExpressionType(*parameter), SpirvStorageClass::Function));
} }
} }