Shader: Add function parameters and return handling

This commit is contained in:
Jérôme Leclercq
2021-02-28 17:50:32 +01:00
parent 9a0f201433
commit b320b5b44e
39 changed files with 818 additions and 327 deletions

View File

@@ -608,6 +608,18 @@ namespace Nz
Append(var.name);
}
void GlslWriter::Visit(ShaderNodes::ReturnStatement& node)
{
if (node.returnExpr)
{
Append("return ");
Visit(node.returnExpr);
Append(";");
}
else
Append("return;");
}
void GlslWriter::Visit(ShaderNodes::OutputVariable& var)
{
Append(var.name);

View File

@@ -13,7 +13,7 @@ namespace Nz
conditionEntry.name = std::move(name);
}
void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters, ShaderNodes::BasicType returnType)
void ShaderAst::AddFunction(std::string name, ShaderNodes::StatementPtr statement, std::vector<FunctionParameter> parameters, ShaderExpressionType returnType)
{
auto& functionEntry = m_functions.emplace_back();
functionEntry.name = std::move(name);

View File

@@ -142,6 +142,11 @@ namespace Nz
PushStatement(ShaderNodes::NoOp::Build());
}
void ShaderAstCloner::Visit(ShaderNodes::ReturnStatement& node)
{
PushStatement(ShaderNodes::ReturnStatement::Build(CloneExpression(node.returnExpr)));
}
void ShaderAstCloner::Visit(ShaderNodes::Sample2D& node)
{
PushExpression(ShaderNodes::Sample2D::Build(CloneExpression(node.sampler), CloneExpression(node.coordinates)));

View File

@@ -95,6 +95,12 @@ namespace Nz
/* Nothing to do */
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::ReturnStatement& node)
{
if (node.returnExpr)
Visit(node.returnExpr);
}
void ShaderAstRecursiveVisitor::Visit(ShaderNodes::Sample2D& node)
{
Visit(node.sampler);

View File

@@ -92,6 +92,11 @@ namespace Nz
Serialize(node);
}
void Visit(ShaderNodes::ReturnStatement& node) override
{
Serialize(node);
}
void Visit(ShaderNodes::Sample2D& node) override
{
Serialize(node);
@@ -286,6 +291,11 @@ namespace Nz
/* Nothing to do */
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::ReturnStatement& node)
{
Node(node.returnExpr);
}
void ShaderAstSerializerBase::Serialize(ShaderNodes::Sample2D& node)
{
Node(node.sampler);
@@ -391,7 +401,8 @@ namespace Nz
m_stream << UInt32(shader.GetFunctionCount());
for (const auto& func : shader.GetFunctions())
{
m_stream << func.name << UInt32(func.returnType);
m_stream << func.name;
SerializeType(func.returnType);
m_stream << UInt32(func.parameters.size());
for (const auto& param : func.parameters)
@@ -634,11 +645,12 @@ namespace Nz
for (UInt32 i = 0; i < funcCount; ++i)
{
std::string name;
ShaderNodes::BasicType retType;
ShaderExpressionType retType;
std::vector<ShaderAst::FunctionParameter> parameters;
Value(name);
Enum(retType);
Type(retType);
Container(parameters);
for (auto& param : parameters)
{
@@ -653,7 +665,7 @@ namespace Nz
ShaderNodes::StatementPtr statement = std::static_pointer_cast<ShaderNodes::Statement>(node);
shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), retType);
shader.AddFunction(std::move(name), std::move(statement), std::move(parameters), std::move(retType));
}
return shader;
@@ -693,6 +705,7 @@ namespace Nz
HandleType(Identifier);
HandleType(IntrinsicCall);
HandleType(NoOp);
HandleType(ReturnStatement);
HandleType(Sample2D);
HandleType(SwizzleOp);
HandleType(StatementBlock);

View File

@@ -349,6 +349,22 @@ namespace Nz
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::ReturnStatement& node)
{
if (m_context->currentFunction->returnType != ShaderExpressionType(ShaderNodes::BasicType::Void))
{
if (MandatoryExpr(node.returnExpr)->GetExpressionType() != m_context->currentFunction->returnType)
throw AstError{ "Return type doesn't match function return type" };
}
else
{
if (node.returnExpr)
throw AstError{ "Unexpected expression for return (function doesn't return)" };
}
ShaderAstRecursiveVisitor::Visit(node);
}
void ShaderAstValidator::Visit(ShaderNodes::Sample2D& node)
{
if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D })

View File

@@ -78,6 +78,11 @@ namespace Nz
throw std::runtime_error("unhandled NoOp node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::ReturnStatement& node)
{
throw std::runtime_error("unhandled ReturnStatement node");
}
void ShaderAstVisitorExcept::Visit(ShaderNodes::Sample2D& /*node*/)
{
throw std::runtime_error("unhandled Sample2D node");

View File

@@ -103,7 +103,7 @@ namespace Nz::ShaderLang
break;
}
tokenType = TokenType::Subtract;
tokenType = TokenType::Minus;
break;
}
@@ -193,7 +193,7 @@ namespace Nz::ShaderLang
char* end;
double value = std::strtod(valueStr.c_str(), &end);
if (end != &str[currentPos])
if (end != &str[currentPos + 1])
throw BadNumber{};
token.data = value;
@@ -204,7 +204,7 @@ namespace Nz::ShaderLang
long long value;
std::from_chars_result r = std::from_chars(&str[start], &str[currentPos + 1], value);
if (r.ptr != &str[currentPos])
if (r.ptr != &str[currentPos + 1])
{
if (r.ec == std::errc::result_out_of_range)
throw NumberOutOfRange{};
@@ -218,7 +218,7 @@ namespace Nz::ShaderLang
break;
}
case '+': tokenType = TokenType::Add; break;
case '+': tokenType = TokenType::Plus; break;
case '*': tokenType = TokenType::Multiply; break;
case ':': tokenType = TokenType::Colon; break;
case ';': tokenType = TokenType::Semicolon; break;

View File

@@ -8,7 +8,33 @@
namespace Nz::ShaderLang
{
void Parser::Parse(const std::vector<Token>& tokens)
namespace
{
std::unordered_map<std::string, ShaderNodes::BasicType> identifierToBasicType = {
{ "bool", ShaderNodes::BasicType::Boolean },
{ "i32", ShaderNodes::BasicType::Int1 },
{ "vec2i32", ShaderNodes::BasicType::Int2 },
{ "vec3i32", ShaderNodes::BasicType::Int3 },
{ "vec4i32", ShaderNodes::BasicType::Int4 },
{ "f32", ShaderNodes::BasicType::Float1 },
{ "vec2f32", ShaderNodes::BasicType::Float2 },
{ "vec3f32", ShaderNodes::BasicType::Float3 },
{ "vec4f32", ShaderNodes::BasicType::Float4 },
{ "mat4x4f32", ShaderNodes::BasicType::Mat4x4 },
{ "sampler2D", ShaderNodes::BasicType::Sampler2D },
{ "void", ShaderNodes::BasicType::Void },
{ "u32", ShaderNodes::BasicType::UInt1 },
{ "vec2u32", ShaderNodes::BasicType::UInt3 },
{ "vec3u32", ShaderNodes::BasicType::UInt3 },
{ "vec4u32", ShaderNodes::BasicType::UInt4 },
};
}
ShaderAst Parser::Parse(const std::vector<Token>& tokens)
{
Context context;
context.tokenCount = tokens.size();
@@ -16,18 +42,28 @@ namespace Nz::ShaderLang
m_context = &context;
for (const Token& token : tokens)
m_context->tokenIndex = -1;
bool reachedEndOfStream = false;
while (!reachedEndOfStream)
{
switch (token.type)
const Token& nextToken = PeekNext();
switch (nextToken.type)
{
case TokenType::FunctionDeclaration:
ParseFunctionDeclaration();
break;
case TokenType::EndOfStream:
reachedEndOfStream = true;
break;
default:
throw UnexpectedToken{};
}
}
return std::move(context.result);
}
const Token& Parser::Advance()
@@ -42,24 +78,34 @@ namespace Nz::ShaderLang
throw ExpectedToken{};
}
void Parser::ExpectNext(TokenType type)
const Token& Parser::ExpectNext(TokenType type)
{
Expect(m_context->tokens[m_context->tokenIndex + 1], type);
const Token& token = Advance();
Expect(token, type);
return token;
}
void Parser::ParseFunctionBody()
const Token& Parser::PeekNext()
{
assert(m_context->tokenIndex + 1 < m_context->tokenCount);
return m_context->tokens[m_context->tokenIndex + 1];
}
ShaderNodes::StatementPtr Parser::ParseFunctionBody()
{
return ParseStatementList();
}
void Parser::ParseFunctionDeclaration()
{
ExpectNext(TokenType::Identifier);
ExpectNext(TokenType::FunctionDeclaration);
std::string functionName = std::get<std::string>(Advance().data);
std::string functionName = ParseIdentifierAsName();
ExpectNext(TokenType::OpenParenthesis);
Advance();
std::vector<ShaderAst::FunctionParameter> parameters;
bool firstParameter = true;
for (;;)
@@ -74,45 +120,192 @@ namespace Nz::ShaderLang
Advance();
}
ParseFunctionParameter();
parameters.push_back(ParseFunctionParameter());
firstParameter = false;
}
ExpectNext(TokenType::ClosingParenthesis);
Advance();
ShaderExpressionType returnType = ShaderNodes::BasicType::Void;
if (PeekNext().type == TokenType::FunctionReturn)
{
Advance();
Advance(); //< Consume ->
std::string returnType = std::get<std::string>(Advance().data);
returnType = ParseIdentifierAsType();
}
ExpectNext(TokenType::OpenCurlyBracket);
Advance();
ParseFunctionBody();
ShaderNodes::StatementPtr functionBody = ParseFunctionBody();
ExpectNext(TokenType::ClosingCurlyBracket);
Advance();
m_context->result.AddFunction(functionName, functionBody, std::move(parameters), returnType);
}
void Parser::ParseFunctionParameter()
ShaderAst::FunctionParameter Parser::ParseFunctionParameter()
{
ExpectNext(TokenType::Identifier);
std::string parameterName = std::get<std::string>(Advance().data);
std::string parameterName = ParseIdentifierAsName();
ExpectNext(TokenType::Colon);
Advance();
ExpectNext(TokenType::Identifier);
std::string parameterType = std::get<std::string>(Advance().data);
ShaderExpressionType parameterType = ParseIdentifierAsType();
return { parameterName, parameterType };
}
const Token& Parser::PeekNext()
ShaderNodes::StatementPtr Parser::ParseReturnStatement()
{
assert(m_context->tokenIndex + 1 < m_context->tokenCount);
return m_context->tokens[m_context->tokenIndex + 1];
ExpectNext(TokenType::Return);
ShaderNodes::ExpressionPtr expr;
if (PeekNext().type != TokenType::Semicolon)
expr = ParseExpression();
return ShaderNodes::ReturnStatement::Build(std::move(expr));
}
ShaderNodes::StatementPtr Parser::ParseStatement()
{
const Token& token = PeekNext();
ShaderNodes::StatementPtr statement;
switch (token.type)
{
case TokenType::Return:
statement = ParseReturnStatement();
break;
default:
break;
}
ExpectNext(TokenType::Semicolon);
return statement;
}
ShaderNodes::StatementPtr Parser::ParseStatementList()
{
std::vector<ShaderNodes::StatementPtr> statements;
while (PeekNext().type != TokenType::ClosingCurlyBracket)
{
statements.push_back(ParseStatement());
}
return ShaderNodes::StatementBlock::Build(std::move(statements));
}
ShaderNodes::ExpressionPtr Parser::ParseBinOpRhs(int exprPrecedence, ShaderNodes::ExpressionPtr lhs)
{
for (;;)
{
const Token& currentOp = PeekNext();
int tokenPrecedence = GetTokenPrecedence(currentOp.type);
if (tokenPrecedence < exprPrecedence)
return lhs;
Advance();
ShaderNodes::ExpressionPtr rhs = ParsePrimaryExpression();
const Token& nextOp = PeekNext();
int nextTokenPrecedence = GetTokenPrecedence(nextOp.type);
if (tokenPrecedence < nextTokenPrecedence)
rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs));
ShaderNodes::BinaryType binaryType;
{
switch (currentOp.type)
{
case TokenType::Plus: binaryType = ShaderNodes::BinaryType::Add; break;
case TokenType::Minus: binaryType = ShaderNodes::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderNodes::BinaryType::Multiply; break;
case TokenType::Divide: binaryType = ShaderNodes::BinaryType::Divide; break;
default: throw UnexpectedToken{};
}
}
lhs = ShaderNodes::BinaryOp::Build(binaryType, std::move(lhs), std::move(rhs));
}
}
ShaderNodes::ExpressionPtr Parser::ParseExpression()
{
return ParseBinOpRhs(0, ParsePrimaryExpression());
}
ShaderNodes::ExpressionPtr Parser::ParseIdentifier()
{
const Token& identifier = ExpectNext(TokenType::Identifier);
return ShaderNodes::Identifier::Build(ShaderNodes::ParameterVariable::Build(std::get<std::string>(identifier.data), ShaderNodes::BasicType::Float3));
}
ShaderNodes::ExpressionPtr Parser::ParseIntegerExpression()
{
const Token& integer = ExpectNext(TokenType::IntegerValue);
return ShaderNodes::Constant::Build(static_cast<Nz::Int32>(std::get<long long>(integer.data)));
}
ShaderNodes::ExpressionPtr Parser::ParseParenthesisExpression()
{
ExpectNext(TokenType::OpenParenthesis);
ShaderNodes::ExpressionPtr expression = ParseExpression();
ExpectNext(TokenType::ClosingParenthesis);
return expression;
}
ShaderNodes::ExpressionPtr Parser::ParsePrimaryExpression()
{
const Token& token = PeekNext();
switch (token.type)
{
case TokenType::BoolFalse: return ShaderNodes::Constant::Build(false);
case TokenType::BoolTrue: return ShaderNodes::Constant::Build(true);
case TokenType::Identifier: return ParseIdentifier();
case TokenType::IntegerValue: return ParseIntegerExpression();
case TokenType::OpenParenthesis: return ParseParenthesisExpression();
default: throw UnexpectedToken{};
}
}
std::string Parser::ParseIdentifierAsName()
{
const Token& identifierToken = ExpectNext(TokenType::Identifier);
std::string identifier = std::get<std::string>(identifierToken.data);
auto it = identifierToBasicType.find(identifier);
if (it != identifierToBasicType.end())
throw ReservedKeyword{};
return identifier;
}
ShaderExpressionType Parser::ParseIdentifierAsType()
{
const Token& identifier = ExpectNext(TokenType::Identifier);
auto it = identifierToBasicType.find(std::get<std::string>(identifier.data));
if (it == identifierToBasicType.end())
throw UnknownType{};
return it->second;
}
int Parser::GetTokenPrecedence(TokenType token)
{
switch (token)
{
case TokenType::Plus: return 20;
case TokenType::Divide: return 40;
case TokenType::Multiply: return 40;
case TokenType::Minus: return 20;
default: return -1;
}
}
}

View File

@@ -69,6 +69,11 @@ namespace Nz::ShaderNodes
visitor.Visit(*this);
}
void ReturnStatement::Visit(ShaderAstVisitor& visitor)
{
visitor.Visit(*this);
}
ShaderExpressionType AssignOp::GetExpressionType() const
{
return left->GetExpressionType();

View File

@@ -611,6 +611,14 @@ namespace Nz
// nothing to do
}
void SpirvAstVisitor::Visit(ShaderNodes::ReturnStatement& node)
{
if (node.returnExpr)
m_currentBlock->Append(SpirvOp::OpReturnValue, EvaluateExpression(node.returnExpr));
else
m_currentBlock->Append(SpirvOp::OpReturn);
}
void SpirvAstVisitor::Visit(ShaderNodes::Sample2D& node)
{
UInt32 typeId = m_writer.GetTypeId(ShaderNodes::BasicType::Float4);

View File

@@ -18,6 +18,7 @@ namespace Nz
template<class... Ts> overloaded(Ts...)->overloaded<Ts...>;
}
struct SpirvConstantCache::Eq
{
bool Compare(const ConstantBool& lhs, const ConstantBool& rhs) const
@@ -353,6 +354,12 @@ namespace Nz
}, v);
}
void Register(const std::vector<TypePtr>& lhs)
{
for (std::size_t i = 0; i < lhs.size(); ++i)
cache.Register(*lhs[i]);
}
template<typename T>
void Register(const std::vector<T>& lhs)
{

View File

@@ -109,6 +109,11 @@ namespace Nz
m_value = Value{ m_writer.ReadLocalVariable(var.name) };
}
void SpirvExpressionLoad::Visit(ShaderNodes::ParameterVariable& var)
{
m_value = Value{ m_writer.ReadParameterVariable(var.name) };
}
void SpirvExpressionLoad::Visit(ShaderNodes::UniformVariable& var)
{
auto uniformVar = m_writer.GetUniformVariable(var.name);

View File

@@ -2,13 +2,13 @@
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/SpirvSection.hpp>
#include <Nazara/Shader/SpirvSectionBase.hpp>
#include <Nazara/Core/Endianness.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
std::size_t SpirvSection::Append(const Raw& raw)
std::size_t SpirvSectionBase::AppendRaw(const Raw& raw)
{
std::size_t offset = GetOutputOffset();
@@ -30,7 +30,7 @@ namespace Nz
codepoint |= UInt32(ptr[pos]) << (j * 8);
}
Append(codepoint);
AppendRaw(codepoint);
}
return offset;

View File

@@ -191,16 +191,17 @@ namespace Nz
{
UInt32 typeId;
UInt32 id;
std::vector<UInt32> paramsId;
};
tsl::ordered_map<std::string, ExtVar> inputIds;
tsl::ordered_map<std::string, ExtVar> outputIds;
tsl::ordered_map<std::string, ExtVar> parameterIds;
tsl::ordered_map<std::string, ExtVar> uniformIds;
std::unordered_map<std::string, UInt32> extensionInstructions;
std::unordered_map<ShaderNodes::BuiltinEntry, ExtVar> builtinIds;
std::unordered_map<std::string, UInt32> varToResult;
std::vector<Func> funcs;
std::vector<SpirvBlock> functionBlocks;
std::vector<UInt32> resultIds;
UInt32 nextVarIndex = 1;
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
@@ -307,7 +308,7 @@ namespace Nz
builtinData.typeId = GetTypeId(builtinType);
builtinData.varId = varId;
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpvDecorationBuiltIn, builtinDecoration);
state.annotations.Append(SpirvOp::OpDecorate, builtinData.varId, SpirvDecoration::BuiltIn, builtinDecoration);
state.builtinIds.emplace(builtin->entry, builtinData);
}
@@ -329,7 +330,7 @@ namespace Nz
state.inputIds.emplace(input.name, std::move(inputData));
if (input.locationIndex)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *input.locationIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *input.locationIndex);
}
for (const auto& output : shader.GetOutputs())
@@ -349,7 +350,7 @@ namespace Nz
state.outputIds.emplace(output.name, std::move(outputData));
if (output.locationIndex)
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationLocation, *output.locationIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Location, *output.locationIndex);
}
for (const auto& uniform : shader.GetUniforms())
@@ -370,8 +371,8 @@ namespace Nz
if (uniform.bindingIndex)
{
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationBinding, *uniform.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpvDecorationDescriptorSet, 0);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::Binding, *uniform.bindingIndex);
state.annotations.Append(SpirvOp::OpDecorate, varId, SpirvDecoration::DescriptorSet, 0);
}
}
@@ -396,77 +397,86 @@ namespace Nz
state.instructions.Append(SpirvOp::OpFunction, GetTypeId(func.returnType), funcData.id, 0, funcData.typeId);
std::vector<SpirvBlock> blocks;
blocks.emplace_back(*this);
state.functionBlocks.clear();
state.functionBlocks.emplace_back(*this);
state.parameterIds.clear();
for (const auto& param : func.parameters)
{
UInt32 paramResultId = AllocateResultId();
funcData.paramsId.push_back(paramResultId);
state.instructions.Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
blocks.back().Append(SpirvOp::OpFunctionParameter, GetTypeId(param.type), paramResultId);
ExtVar parameterData;
parameterData.pointerTypeId = GetPointerTypeId(param.type, SpirvStorageClass::Function);
parameterData.typeId = GetTypeId(param.type);
parameterData.varId = paramResultId;
state.parameterIds.emplace(param.name, std::move(parameterData));
}
SpirvAstVisitor visitor(*this, blocks);
SpirvAstVisitor visitor(*this, state.functionBlocks);
visitor.Visit(functionStatements[funcIndex]);
if (func.returnType == ShaderNodes::BasicType::Void)
blocks.back().Append(SpirvOp::OpReturn);
else
throw std::runtime_error("returning values from functions is not yet supported"); //< TODO
if (!state.functionBlocks.back().IsTerminated())
{
assert(func.returnType == ShaderExpressionType(ShaderNodes::BasicType::Void));
state.functionBlocks.back().Append(SpirvOp::OpReturn);
}
blocks.back().Append(SpirvOp::OpFunctionEnd);
for (SpirvBlock& block : state.functionBlocks)
state.instructions.AppendSection(block);
for (SpirvBlock& block : blocks)
state.instructions.Append(block);
state.instructions.Append(SpirvOp::OpFunctionEnd);
}
assert(entryPointIndex != std::numeric_limits<std::size_t>::max());
m_currentState->constantTypeCache.Write(m_currentState->annotations, m_currentState->constants, m_currentState->debugInfo);
AppendHeader();
SpvExecutionModel execModel;
const auto& entryFuncData = shader.GetFunction(entryPointIndex);
const auto& entryFunc = state.funcs[entryPointIndex];
assert(m_context.shader);
switch (m_context.shader->GetStage())
if (entryPointIndex != std::numeric_limits<std::size_t>::max())
{
case ShaderStageType::Fragment:
execModel = SpvExecutionModelFragment;
break;
SpvExecutionModel execModel;
const auto& entryFuncData = shader.GetFunction(entryPointIndex);
const auto& entryFunc = state.funcs[entryPointIndex];
case ShaderStageType::Vertex:
execModel = SpvExecutionModelVertex;
break;
assert(m_context.shader);
switch (m_context.shader->GetStage())
{
case ShaderStageType::Fragment:
execModel = SpvExecutionModelFragment;
break;
default:
throw std::runtime_error("not yet implemented");
case ShaderStageType::Vertex:
execModel = SpvExecutionModelVertex;
break;
default:
throw std::runtime_error("not yet implemented");
}
// OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos
state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
{
appender(execModel);
appender(entryFunc.id);
appender(entryFuncData.name);
for (const auto& [name, varData] : state.builtinIds)
appender(varData.varId);
for (const auto& [name, varData] : state.inputIds)
appender(varData.varId);
for (const auto& [name, varData] : state.outputIds)
appender(varData.varId);
});
if (m_context.shader->GetStage() == ShaderStageType::Fragment)
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft);
}
// OpEntryPoint Vertex %main "main" %outNormal %inNormals %outTexCoords %inTexCoord %_ %inPos
state.header.AppendVariadic(SpirvOp::OpEntryPoint, [&](const auto& appender)
{
appender(execModel);
appender(entryFunc.id);
appender(entryFuncData.name);
for (const auto& [name, varData] : state.builtinIds)
appender(varData.varId);
for (const auto& [name, varData] : state.inputIds)
appender(varData.varId);
for (const auto& [name, varData] : state.outputIds)
appender(varData.varId);
});
if (m_context.shader->GetStage() == ShaderStageType::Fragment)
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpvExecutionModeOriginUpperLeft);
std::vector<UInt32> ret;
MergeSections(ret, state.header);
MergeSections(ret, state.debugInfo);
@@ -489,14 +499,14 @@ namespace Nz
void SpirvWriter::AppendHeader()
{
m_currentState->header.Append(SpvMagicNumber); //< Spir-V magic number
m_currentState->header.AppendRaw(SpvMagicNumber); //< Spir-V magic number
UInt32 version = (m_environment.spvMajorVersion << 16) | m_environment.spvMinorVersion << 8;
m_currentState->header.Append(version); //< Spir-V version number (1.0 for compatibility)
m_currentState->header.Append(0); //< Generator identifier (TODO: Register generator to Khronos)
m_currentState->header.AppendRaw(version); //< Spir-V version number (1.0 for compatibility)
m_currentState->header.AppendRaw(0); //< Generator identifier (TODO: Register generator to Khronos)
m_currentState->header.Append(m_currentState->nextVarIndex); //< Bound (ID count)
m_currentState->header.Append(0); //< Instruction schema (required to be 0 for now)
m_currentState->header.AppendRaw(m_currentState->nextVarIndex); //< Bound (ID count)
m_currentState->header.AppendRaw(0); //< Instruction schema (required to be 0 for now)
m_currentState->header.Append(SpirvOp::OpCapability, SpvCapabilityShader);
@@ -506,6 +516,20 @@ namespace Nz
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
}
SpirvConstantCache::Function SpirvWriter::BuildFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
for (const auto& parameter : parameters)
parameterTypes.push_back(SpirvConstantCache::BuildPointerType(*m_context.shader, parameter.type, SpirvStorageClass::Function));
return SpirvConstantCache::Function{
SpirvConstantCache::BuildType(*m_context.shader, retType),
std::move(parameterTypes)
};
}
UInt32 SpirvWriter::GetConstantId(const ShaderConstantValue& value) const
{
return m_currentState->constantTypeCache.GetId(*SpirvConstantCache::BuildConstant(value));
@@ -513,18 +537,7 @@ namespace Nz
UInt32 SpirvWriter::GetFunctionTypeId(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
for (const auto& parameter : parameters)
parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type));
return m_currentState->constantTypeCache.GetId({
SpirvConstantCache::Function {
SpirvConstantCache::BuildType(*m_context.shader, retType),
std::move(parameterTypes)
}
});
return m_currentState->constantTypeCache.GetId({ BuildFunctionType(retType, parameters) });
}
auto SpirvWriter::GetBuiltinVariable(ShaderNodes::BuiltinEntry builtin) const -> const ExtVar&
@@ -602,6 +615,22 @@ namespace Nz
return it->second;
}
UInt32 SpirvWriter::ReadParameterVariable(const std::string& name)
{
auto it = m_currentState->parameterIds.find(name);
assert(it != m_currentState->parameterIds.end());
return ReadVariable(it.value());
}
std::optional<UInt32> SpirvWriter::ReadParameterVariable(const std::string& name, OnlyCache)
{
auto it = m_currentState->parameterIds.find(name);
assert(it != m_currentState->parameterIds.end());
return ReadVariable(it.value(), OnlyCache{});
}
UInt32 SpirvWriter::ReadUniformVariable(const std::string& name)
{
auto it = m_currentState->uniformIds.find(name);
@@ -623,7 +652,7 @@ namespace Nz
if (!var.valueId.has_value())
{
UInt32 resultId = AllocateResultId();
m_currentState->instructions.Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId);
m_currentState->functionBlocks.back().Append(SpirvOp::OpLoad, var.typeId, resultId, var.varId);
var.valueId = resultId;
}
@@ -646,18 +675,7 @@ namespace Nz
UInt32 SpirvWriter::RegisterFunctionType(ShaderExpressionType retType, const std::vector<ShaderAst::FunctionParameter>& parameters)
{
std::vector<SpirvConstantCache::TypePtr> parameterTypes;
parameterTypes.reserve(parameters.size());
for (const auto& parameter : parameters)
parameterTypes.push_back(SpirvConstantCache::BuildType(*m_context.shader, parameter.type));
return m_currentState->constantTypeCache.Register({
SpirvConstantCache::Function {
SpirvConstantCache::BuildType(*m_context.shader, retType),
std::move(parameterTypes)
}
});
return m_currentState->constantTypeCache.Register({ BuildFunctionType(retType, parameters) });
}
UInt32 SpirvWriter::RegisterPointerType(ShaderExpressionType type, SpirvStorageClass storageClass)