diff --git a/src/ShaderNode/DataModels/VecDot.cpp b/src/ShaderNode/DataModels/VecDot.cpp new file mode 100644 index 000000000..47b67b72c --- /dev/null +++ b/src/ShaderNode/DataModels/VecDot.cpp @@ -0,0 +1,167 @@ +#include +#include + +VecDot::VecDot(ShaderGraph& graph) : +ShaderNode(graph) +{ + m_output = std::make_shared(); + UpdateOutput(); +} + +Nz::ShaderAst::ExpressionPtr VecDot::GetExpression(Nz::ShaderAst::ExpressionPtr* expressions, std::size_t count) const +{ + assert(count == 2); + using namespace Nz::ShaderAst; + return BinaryFunc::Build(BinaryIntrinsic::DotProduct, expressions[0], expressions[1]); +} + +QString VecDot::caption() const +{ + static QString caption = "Vector dot"; + return caption; +} + +QString VecDot::name() const +{ + static QString name = "vec_dot"; + return name; +} + +QtNodes::NodeDataType VecDot::dataType(QtNodes::PortType portType, QtNodes::PortIndex portIndex) const +{ + switch (portType) + { + case QtNodes::PortType::In: + { + assert(portIndex == 0 || portIndex == 1); + return VecData::Type(); + } + + case QtNodes::PortType::Out: + { + assert(portIndex == 0); + return FloatData::Type(); + } + } + + assert(false); + throw std::runtime_error("Invalid port type"); +} + +unsigned int VecDot::nPorts(QtNodes::PortType portType) const +{ + switch (portType) + { + case QtNodes::PortType::In: return 2; + case QtNodes::PortType::Out: return 1; + } + + assert(false); + throw std::runtime_error("Invalid port type"); +} + +std::shared_ptr VecDot::outData(QtNodes::PortIndex port) +{ + assert(port == 0); + return m_output; +} + +void VecDot::setInData(std::shared_ptr value, int index) +{ + assert(index == 0 || index == 1); + + std::shared_ptr castedValue; + if (value) + { + assert(dynamic_cast(value.get()) != nullptr); + + castedValue = std::static_pointer_cast(value); + } + + if (index == 0) + m_lhs = std::move(castedValue); + else + m_rhs = std::move(castedValue); + + UpdateOutput(); +} + +QtNodes::NodeValidationState VecDot::validationState() const +{ + if (!m_lhs || !m_rhs) + return QtNodes::NodeValidationState::Error; + + if (m_lhs->componentCount != m_rhs->componentCount) + return QtNodes::NodeValidationState::Error; + + return QtNodes::NodeValidationState::Valid; +} + +QString VecDot::validationMessage() const +{ + if (!m_lhs || !m_rhs) + return "Missing operands"; + + if (m_lhs->componentCount != m_rhs->componentCount) + return "Incompatible components count (left has " + QString::number(m_lhs->componentCount) + ", right has " + QString::number(m_rhs->componentCount) + ")"; + + return QString(); +} +bool VecDot::ComputePreview(QPixmap& pixmap) +{ + if (validationState() != QtNodes::NodeValidationState::Valid) + return false; + + pixmap = QPixmap::fromImage(m_output->preview); + return true; +} + +void VecDot::UpdateOutput() +{ + if (validationState() != QtNodes::NodeValidationState::Valid) + { + m_output->preview = QImage(1, 1, QImage::Format_RGBA8888); + m_output->preview.fill(QColor::fromRgb(0, 0, 0, 0)); + return; + } + + const QImage& leftPreview = m_lhs->preview; + const QImage& rightPreview = m_rhs->preview; + int maxWidth = std::max(leftPreview.width(), rightPreview.width()); + int maxHeight = std::max(leftPreview.height(), rightPreview.height()); + + // Exploit COW + QImage leftResized = leftPreview; + if (leftResized.width() != maxWidth || leftResized.height() != maxHeight) + leftResized = leftResized.scaled(maxWidth, maxHeight); + + QImage rightResized = rightPreview; + if (rightResized.width() != maxWidth || rightResized.height() != maxHeight) + rightResized = rightResized.scaled(maxWidth, maxHeight); + + m_output->preview = QImage(maxWidth, maxHeight, QImage::Format_RGBA8888); + + const uchar* left = leftResized.constBits(); + const uchar* right = rightPreview.constBits(); + uchar* output = m_output->preview.bits(); + + std::size_t pixelCount = maxWidth * maxHeight; + for (std::size_t i = 0; i < pixelCount; ++i) + { + unsigned int acc = 0; + for (std::size_t j = 0; j < m_lhs->componentCount; ++j) + acc += left[j] * right[j] / 255; + + unsigned int result = static_cast(std::min(acc, 255U)); + for (std::size_t j = 0; j < 3; ++j) + *output++ = result; + *output++ = 255; //< leave alpha at maximum + + left += 4; + right += 4; + } + + Q_EMIT dataUpdated(0); + + UpdatePreview(); +} diff --git a/src/ShaderNode/DataModels/VecDot.hpp b/src/ShaderNode/DataModels/VecDot.hpp new file mode 100644 index 000000000..dc92bb5c8 --- /dev/null +++ b/src/ShaderNode/DataModels/VecDot.hpp @@ -0,0 +1,43 @@ +#pragma once + +#ifndef NAZARA_SHADERNODES_VECDOT_HPP +#define NAZARA_SHADERNODES_VECDOT_HPP + +#include +#include +#include + +class VecDot : public ShaderNode +{ + public: + VecDot(ShaderGraph& graph); + ~VecDot() = default; + + Nz::ShaderAst::ExpressionPtr GetExpression(Nz::ShaderAst::ExpressionPtr* expressions, std::size_t count) const override; + + QString caption() const override; + QString name() const override; + + unsigned int nPorts(QtNodes::PortType portType) const override; + + QtNodes::NodeDataType dataType(QtNodes::PortType portType, QtNodes::PortIndex portIndex) const override; + + std::shared_ptr outData(QtNodes::PortIndex port) override; + + void setInData(std::shared_ptr value, int index) override; + + QtNodes::NodeValidationState validationState() const override; + QString validationMessage() const override; + + private: + bool ComputePreview(QPixmap& pixmap) override; + void UpdateOutput(); + + std::shared_ptr m_lhs; + std::shared_ptr m_rhs; + std::shared_ptr m_output; +}; + +#include + +#endif diff --git a/src/ShaderNode/DataModels/VecDot.inl b/src/ShaderNode/DataModels/VecDot.inl new file mode 100644 index 000000000..ba9301bc9 --- /dev/null +++ b/src/ShaderNode/DataModels/VecDot.inl @@ -0,0 +1 @@ +#include diff --git a/src/ShaderNode/ShaderGraph.cpp b/src/ShaderNode/ShaderGraph.cpp index 01b5c6cf9..57b9429e3 100644 --- a/src/ShaderNode/ShaderGraph.cpp +++ b/src/ShaderNode/ShaderGraph.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -400,6 +401,7 @@ std::shared_ptr ShaderGraph::BuildRegistry() RegisterShaderNode(*this, registry, "Texture"); RegisterShaderNode(*this, registry, "Vector operations"); RegisterShaderNode(*this, registry, "Vector operations"); + RegisterShaderNode(*this, registry, "Vector operations"); RegisterShaderNode(*this, registry, "Vector operations"); RegisterShaderNode(*this, registry, "Vector operations"); RegisterShaderNode(*this, registry, "Constants");