diff --git a/include/Nazara/Network/ENetHost.hpp b/include/Nazara/Network/ENetHost.hpp index eb5c7d6de..789c3d9a6 100644 --- a/include/Nazara/Network/ENetHost.hpp +++ b/include/Nazara/Network/ENetHost.hpp @@ -55,8 +55,8 @@ namespace Nz ENetPeer* Connect(const String& hostName, NetProtocol protocol = NetProtocol_Any, const String& service = "http", ResolveError* error = nullptr, std::size_t channelCount = 0, UInt32 data = 0); inline bool Create(NetProtocol protocol, UInt16 port, std::size_t peerCount, std::size_t channelCount = 0); - bool Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount = 0); - bool Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth); + bool Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount = 0); + bool Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth); void Destroy(); void Flush(); @@ -152,6 +152,7 @@ namespace Nz UInt32 m_totalReceivedPackets; UInt64 m_totalSentData; UInt64 m_totalReceivedData; + bool m_allowsIncomingConnections; bool m_continueSending; bool m_isSimulationEnabled; bool m_recalculateBandwidthLimits; diff --git a/src/Nazara/Network/ENetHost.cpp b/src/Nazara/Network/ENetHost.cpp index 9fdfe284f..3fe2f3a47 100644 --- a/src/Nazara/Network/ENetHost.cpp +++ b/src/Nazara/Network/ENetHost.cpp @@ -143,14 +143,14 @@ namespace Nz return Connect(hostnameAddress, channelCount, data); } - bool ENetHost::Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount) + bool ENetHost::Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount) { - return Create(address, peerCount, channelCount, 0, 0); + return Create(listenAddress, peerCount, channelCount, 0, 0); } - bool ENetHost::Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth) + bool ENetHost::Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth) { - NazaraAssert(address.IsValid(), "Invalid listening address"); + NazaraAssert(listenAddress.IsValid(), "Invalid listening address"); if (peerCount > ENetConstants::ENetProtocol_MaximumPeerId) { @@ -158,10 +158,11 @@ namespace Nz return false; } - if (!InitSocket(address)) + if (!InitSocket(listenAddress)) return false; - m_address = address; + m_address = listenAddress; + m_allowsIncomingConnections = (listenAddress.IsValid() && !listenAddress.IsLoopback()); m_randomSeed = *reinterpret_cast(this); m_randomSeed += s_randomGenerator(); m_randomSeed = (m_randomSeed << 16) | (m_randomSeed >> 16); @@ -236,6 +237,8 @@ namespace Nz break; } + if (!m_allowsIncomingConnections && m_connectedPeers == 0) + switch (ReceiveIncomingCommands(event)) { case 1: @@ -323,7 +326,7 @@ namespace Nz m_socket.SetReceiveBufferSize(ENetConstants::ENetHost_ReceiveBufferSize); m_socket.SetSendBufferSize(ENetConstants::ENetHost_SendBufferSize); - if (!address.IsLoopback()) + if (address.IsValid() && !address.IsLoopback()) { if (m_socket.Bind(address) != SocketState_Bound) { @@ -407,6 +410,9 @@ namespace Nz ENetPeer* ENetHost::HandleConnect(ENetProtocolHeader* /*header*/, ENetProtocol* command) { + if (!m_allowsIncomingConnections) + return nullptr; + UInt32 channelCount = NetToHost(command->connect.channelCount); if (channelCount < ENetProtocol_MinimumChannelCount || channelCount > ENetProtocol_MaximumChannelCount)