Shader: Fix parsing of unary/dot/indices/and/or

This commit is contained in:
Jérôme Leclercq 2022-03-21 23:11:28 +01:00
parent 4fa3de519c
commit a54f70fd24
4 changed files with 173 additions and 93 deletions

View File

@ -85,6 +85,10 @@ namespace Nz::ShaderLang
void ParseModuleStatement(std::vector<ShaderAst::ExprValue> attributes);
void ParseVariableDeclaration(std::string& name, ShaderAst::ExpressionValue<ShaderAst::ExpressionType>& type, ShaderAst::ExpressionPtr& initialValue);
ShaderAst::ExpressionPtr BuildIdentifierAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs);
ShaderAst::ExpressionPtr BuildIndexAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs);
ShaderAst::ExpressionPtr BuildBinary(ShaderAst::BinaryType binaryType, ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs);
// Statements
ShaderAst::StatementPtr ParseAliasDeclaration();
ShaderAst::StatementPtr ParseBranchStatement();
@ -110,10 +114,10 @@ namespace Nz::ShaderLang
ShaderAst::ExpressionPtr ParseBinOpRhs(int exprPrecedence, ShaderAst::ExpressionPtr lhs);
ShaderAst::ExpressionPtr ParseConstSelectExpression();
ShaderAst::ExpressionPtr ParseExpression();
std::vector<ShaderAst::ExpressionPtr> ParseExpressionList(TokenType terminationToken);
ShaderAst::ExpressionPtr ParseFloatingPointExpression();
ShaderAst::ExpressionPtr ParseIdentifier();
ShaderAst::ExpressionPtr ParseIntegerExpression();
std::vector<ShaderAst::ExpressionPtr> ParseParameters();
ShaderAst::ExpressionPtr ParseParenthesisExpression();
ShaderAst::ExpressionPtr ParsePrimaryExpression();
ShaderAst::ExpressionPtr ParseStringExpression();

View File

@ -390,6 +390,24 @@ namespace Nz::ShaderLang
Expect(Advance(), TokenType::Semicolon);
}
ShaderAst::ExpressionPtr Parser::BuildIdentifierAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs)
{
if (rhs->GetType() == ShaderAst::NodeType::IdentifierExpression)
return ShaderBuilder::AccessMember(std::move(lhs), { std::move(SafeCast<ShaderAst::IdentifierExpression&>(*rhs).identifier) });
else
return BuildIndexAccess(std::move(lhs), std::move(rhs));
}
ShaderAst::ExpressionPtr Parser::BuildIndexAccess(ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs)
{
return ShaderBuilder::AccessIndex(std::move(lhs), std::move(rhs));
}
ShaderAst::ExpressionPtr Parser::BuildBinary(ShaderAst::BinaryType binaryType, ShaderAst::ExpressionPtr lhs, ShaderAst::ExpressionPtr rhs)
{
return ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs));
}
ShaderAst::StatementPtr Parser::ParseAliasDeclaration()
{
Expect(Advance(), TokenType::Alias);
@ -1124,59 +1142,25 @@ namespace Nz::ShaderLang
if (tokenPrecedence < exprPrecedence)
return lhs;
bool c = false;
while (currentTokenType == TokenType::Dot || currentTokenType == TokenType::OpenSquareBracket)
{
c = true;
if (currentTokenType == TokenType::Dot)
{
std::unique_ptr<ShaderAst::AccessIdentifierExpression> accessMemberNode = std::make_unique<ShaderAst::AccessIdentifierExpression>();
accessMemberNode->expr = std::move(lhs);
do
{
Consume();
accessMemberNode->identifiers.push_back(ParseIdentifierAsName());
} while (Peek().type == TokenType::Dot);
lhs = std::move(accessMemberNode);
}
else
{
assert(currentTokenType == TokenType::OpenSquareBracket);
std::unique_ptr<ShaderAst::AccessIndexExpression> indexNode = std::make_unique<ShaderAst::AccessIndexExpression>();
indexNode->expr = std::move(lhs);
do
{
Consume();
indexNode->indices.push_back(ParseExpression());
}
while (Peek().type == TokenType::Comma);
Expect(Advance(), TokenType::ClosingSquareBracket);
lhs = std::move(indexNode);
}
currentTokenType = Peek().type;
}
if (currentTokenType == TokenType::OpenParenthesis)
{
// Function call
auto parameters = ParseParameters();
lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters));
Consume();
c = true;
// Function call
auto parameters = ParseExpressionList(TokenType::ClosingParenthesis);
lhs = ShaderBuilder::CallFunction(std::move(lhs), std::move(parameters));
continue;
}
if (c)
if (currentTokenType == TokenType::OpenSquareBracket)
{
Consume();
// Indices
auto parameters = ParseExpressionList(TokenType::ClosingSquareBracket);
lhs = ShaderBuilder::AccessIndex(std::move(lhs), std::move(parameters));
continue;
}
Consume();
ShaderAst::ExpressionPtr rhs = ParsePrimaryExpression();
@ -1187,28 +1171,30 @@ namespace Nz::ShaderLang
if (tokenPrecedence < nextTokenPrecedence)
rhs = ParseBinOpRhs(tokenPrecedence + 1, std::move(rhs));
ShaderAst::BinaryType binaryType;
lhs = [&]
{
switch (currentTokenType)
{
case TokenType::Divide: binaryType = ShaderAst::BinaryType::Divide; break;
case TokenType::Equal: binaryType = ShaderAst::BinaryType::CompEq; break;
case TokenType::LessThan: binaryType = ShaderAst::BinaryType::CompLt; break;
case TokenType::LessThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break;
case TokenType::LogicalAnd: binaryType = ShaderAst::BinaryType::LogicalAnd; break;
case TokenType::LogicalOr: binaryType = ShaderAst::BinaryType::LogicalOr; break;
case TokenType::GreaterThan: binaryType = ShaderAst::BinaryType::CompGt; break;
case TokenType::GreaterThanEqual: binaryType = ShaderAst::BinaryType::CompGe; break;
case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break;
case TokenType::Multiply: binaryType = ShaderAst::BinaryType::Multiply; break;
case TokenType::NotEqual: binaryType = ShaderAst::BinaryType::CompNe; break;
case TokenType::Plus: binaryType = ShaderAst::BinaryType::Add; break;
case TokenType::Dot:
return BuildIdentifierAccess(std::move(lhs), std::move(rhs));
case TokenType::Divide: return BuildBinary(ShaderAst::BinaryType::Divide, std::move(lhs), std::move(rhs));
case TokenType::Equal: return BuildBinary(ShaderAst::BinaryType::CompEq, std::move(lhs), std::move(rhs));
case TokenType::LessThan: return BuildBinary(ShaderAst::BinaryType::CompLt, std::move(lhs), std::move(rhs));
case TokenType::LessThanEqual: return BuildBinary(ShaderAst::BinaryType::CompLe, std::move(lhs), std::move(rhs));
case TokenType::LogicalAnd: return BuildBinary(ShaderAst::BinaryType::LogicalAnd, std::move(lhs), std::move(rhs));
case TokenType::LogicalOr: return BuildBinary(ShaderAst::BinaryType::LogicalOr, std::move(lhs), std::move(rhs));
case TokenType::GreaterThan: return BuildBinary(ShaderAst::BinaryType::CompGt, std::move(lhs), std::move(rhs));
case TokenType::GreaterThanEqual: return BuildBinary(ShaderAst::BinaryType::CompGe, std::move(lhs), std::move(rhs));
case TokenType::Minus: return BuildBinary(ShaderAst::BinaryType::Subtract, std::move(lhs), std::move(rhs));
case TokenType::Multiply: return BuildBinary(ShaderAst::BinaryType::Multiply, std::move(lhs), std::move(rhs));
case TokenType::NotEqual: return BuildBinary(ShaderAst::BinaryType::CompNe, std::move(lhs), std::move(rhs));
case TokenType::Plus: return BuildBinary(ShaderAst::BinaryType::Add, std::move(lhs), std::move(rhs));
default:
throw UnexpectedToken{};
}
}
lhs = ShaderBuilder::Binary(binaryType, std::move(lhs), std::move(rhs));
}
}();
}
}
@ -1237,6 +1223,24 @@ namespace Nz::ShaderLang
return ParseBinOpRhs(0, ParsePrimaryExpression());
}
std::vector<ShaderAst::ExpressionPtr> Parser::ParseExpressionList(TokenType terminationToken)
{
std::vector<ShaderAst::ExpressionPtr> parameters;
bool first = true;
while (Peek().type != terminationToken)
{
if (!first)
Expect(Advance(), TokenType::Comma);
first = false;
parameters.push_back(ParseExpression());
}
Expect(Advance(), terminationToken);
return parameters;
}
ShaderAst::ExpressionPtr Parser::ParseFloatingPointExpression()
{
const Token& floatingPointToken = Expect(Advance(), TokenType::FloatingPointValue);
@ -1257,26 +1261,6 @@ namespace Nz::ShaderLang
return ShaderBuilder::Constant(SafeCast<Int32>(std::get<long long>(integerToken.data))); //< FIXME
}
std::vector<ShaderAst::ExpressionPtr> Parser::ParseParameters()
{
Expect(Advance(), TokenType::OpenParenthesis);
std::vector<ShaderAst::ExpressionPtr> parameters;
bool first = true;
while (Peek().type != TokenType::ClosingParenthesis)
{
if (!first)
Expect(Advance(), TokenType::Comma);
first = false;
parameters.push_back(ParseExpression());
}
Expect(Advance(), TokenType::ClosingParenthesis);
return parameters;
}
ShaderAst::ExpressionPtr Parser::ParseParenthesisExpression()
{
Expect(Advance(), TokenType::OpenParenthesis);
@ -1314,7 +1298,7 @@ namespace Nz::ShaderLang
case TokenType::Minus:
{
Consume();
ShaderAst::ExpressionPtr expr = ParsePrimaryExpression();
ShaderAst::ExpressionPtr expr = ParseExpression();
return ShaderBuilder::Unary(ShaderAst::UnaryType::Minus, std::move(expr));
}
@ -1322,7 +1306,7 @@ namespace Nz::ShaderLang
case TokenType::Plus:
{
Consume();
ShaderAst::ExpressionPtr expr = ParsePrimaryExpression();
ShaderAst::ExpressionPtr expr = ParseExpression();
return ShaderBuilder::Unary(ShaderAst::UnaryType::Plus, std::move(expr));
}
@ -1330,7 +1314,7 @@ namespace Nz::ShaderLang
case TokenType::Not:
{
Consume();
ShaderAst::ExpressionPtr expr = ParsePrimaryExpression();
ShaderAst::ExpressionPtr expr = ParseExpression();
return ShaderBuilder::Unary(ShaderAst::UnaryType::LogicalNot, std::move(expr));
}
@ -1404,12 +1388,12 @@ namespace Nz::ShaderLang
switch (token)
{
case TokenType::Divide: return 80;
case TokenType::Dot: return 100;
case TokenType::Dot: return 150;
case TokenType::Equal: return 50;
case TokenType::LessThan: return 40;
case TokenType::LessThanEqual: return 40;
case TokenType::LogicalAnd: return 120;
case TokenType::LogicalOr: return 140;
case TokenType::LogicalAnd: return 20;
case TokenType::LogicalOr: return 10;
case TokenType::GreaterThan: return 40;
case TokenType::GreaterThanEqual: return 40;
case TokenType::Multiply: return 80;

View File

@ -88,6 +88,97 @@ OpStore
OpBranch
OpLabel
OpReturn
OpFunctionEnd)");
}
WHEN("using a more complex branch")
{
std::string_view nzslSource = R"(
[nzsl_version("1.0")]
module;
struct inputStruct
{
value: f32
}
external
{
[set(0), binding(0)] data: uniform[inputStruct]
}
[entry(frag)]
fn main()
{
let value: f32;
if (data.value > 42.0 || data.value <= 50.0 && data.value < 0.0)
value = 1.0;
else
value = 0.0;
}
)";
Nz::ShaderAst::ModulePtr shaderModule = Nz::ShaderLang::Parse(nzslSource);
shaderModule = SanitizeModule(*shaderModule);
ExpectGLSL(*shaderModule, R"(
void main()
{
float value;
if ((data.value > (42.000000)) || ((data.value <= (50.000000)) && (data.value < (0.000000))))
{
value = 1.000000;
}
else
{
value = 0.000000;
}
}
)");
ExpectNZSL(*shaderModule, R"(
[entry(frag)]
fn main()
{
let value: f32;
if ((data.value > (42.000000)) || ((data.value <= (50.000000)) && (data.value < (0.000000))))
{
value = 1.000000;
}
else
{
value = 0.000000;
}
}
)");
ExpectSPIRV(*shaderModule, R"(
OpFunction
OpLabel
OpVariable
OpAccessChain
OpLoad
OpFOrdGreaterThanEqual
OpAccessChain
OpLoad
OpFOrdLessThanEqual
OpAccessChain
OpLoad
OpFOrdLessThan
OpLogicalAnd
OpLogicalOr
OpSelectionMerge
OpBranchConditional
OpLabel
OpStore
OpBranch
OpLabel
OpStore
OpBranch
OpLabel
OpReturn
OpFunctionEnd)");
}

View File

@ -28,7 +28,7 @@ fn GetValue() -> f32
fn main() -> FragOut
{
let output: FragOut;
output.value = GetValue();
output.value = -GetValue();
return output;
}
@ -49,7 +49,7 @@ layout(location = 0) out float _NzOut_value;
void main()
{
FragOut output_;
output_.value = GetValue();
output_.value = -GetValue();
_NzOut_value = output_.value;
return;
@ -66,7 +66,7 @@ fn GetValue() -> f32
fn main() -> FragOut
{
let output: FragOut;
output.value = GetValue();
output.value = -GetValue();
return output;
}
)");
@ -80,6 +80,7 @@ OpFunction
OpLabel
OpVariable
OpFunctionCall
OpFNegate
OpAccessChain
OpStore
OpLoad