Network/ENetHost: Dismiss external peer connection if listen address is loopback

This commit is contained in:
Jérôme Leclercq 2017-09-01 15:22:27 +02:00
parent 7074876d68
commit e37a7ad5fd
2 changed files with 16 additions and 9 deletions

View File

@ -55,8 +55,8 @@ namespace Nz
ENetPeer* Connect(const String& hostName, NetProtocol protocol = NetProtocol_Any, const String& service = "http", ResolveError* error = nullptr, std::size_t channelCount = 0, UInt32 data = 0); ENetPeer* Connect(const String& hostName, NetProtocol protocol = NetProtocol_Any, const String& service = "http", ResolveError* error = nullptr, std::size_t channelCount = 0, UInt32 data = 0);
inline bool Create(NetProtocol protocol, UInt16 port, std::size_t peerCount, std::size_t channelCount = 0); inline bool Create(NetProtocol protocol, UInt16 port, std::size_t peerCount, std::size_t channelCount = 0);
bool Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount = 0); bool Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount = 0);
bool Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth); bool Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth);
void Destroy(); void Destroy();
void Flush(); void Flush();
@ -152,6 +152,7 @@ namespace Nz
UInt32 m_totalReceivedPackets; UInt32 m_totalReceivedPackets;
UInt64 m_totalSentData; UInt64 m_totalSentData;
UInt64 m_totalReceivedData; UInt64 m_totalReceivedData;
bool m_allowsIncomingConnections;
bool m_continueSending; bool m_continueSending;
bool m_isSimulationEnabled; bool m_isSimulationEnabled;
bool m_recalculateBandwidthLimits; bool m_recalculateBandwidthLimits;

View File

@ -143,14 +143,14 @@ namespace Nz
return Connect(hostnameAddress, channelCount, data); return Connect(hostnameAddress, channelCount, data);
} }
bool ENetHost::Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount) bool ENetHost::Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount)
{ {
return Create(address, peerCount, channelCount, 0, 0); return Create(listenAddress, peerCount, channelCount, 0, 0);
} }
bool ENetHost::Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth) bool ENetHost::Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth)
{ {
NazaraAssert(address.IsValid(), "Invalid listening address"); NazaraAssert(listenAddress.IsValid(), "Invalid listening address");
if (peerCount > ENetConstants::ENetProtocol_MaximumPeerId) if (peerCount > ENetConstants::ENetProtocol_MaximumPeerId)
{ {
@ -158,10 +158,11 @@ namespace Nz
return false; return false;
} }
if (!InitSocket(address)) if (!InitSocket(listenAddress))
return false; return false;
m_address = address; m_address = listenAddress;
m_allowsIncomingConnections = (listenAddress.IsValid() && !listenAddress.IsLoopback());
m_randomSeed = *reinterpret_cast<UInt32*>(this); m_randomSeed = *reinterpret_cast<UInt32*>(this);
m_randomSeed += s_randomGenerator(); m_randomSeed += s_randomGenerator();
m_randomSeed = (m_randomSeed << 16) | (m_randomSeed >> 16); m_randomSeed = (m_randomSeed << 16) | (m_randomSeed >> 16);
@ -236,6 +237,8 @@ namespace Nz
break; break;
} }
if (!m_allowsIncomingConnections && m_connectedPeers == 0)
switch (ReceiveIncomingCommands(event)) switch (ReceiveIncomingCommands(event))
{ {
case 1: case 1:
@ -323,7 +326,7 @@ namespace Nz
m_socket.SetReceiveBufferSize(ENetConstants::ENetHost_ReceiveBufferSize); m_socket.SetReceiveBufferSize(ENetConstants::ENetHost_ReceiveBufferSize);
m_socket.SetSendBufferSize(ENetConstants::ENetHost_SendBufferSize); m_socket.SetSendBufferSize(ENetConstants::ENetHost_SendBufferSize);
if (!address.IsLoopback()) if (address.IsValid() && !address.IsLoopback())
{ {
if (m_socket.Bind(address) != SocketState_Bound) if (m_socket.Bind(address) != SocketState_Bound)
{ {
@ -407,6 +410,9 @@ namespace Nz
ENetPeer* ENetHost::HandleConnect(ENetProtocolHeader* /*header*/, ENetProtocol* command) ENetPeer* ENetHost::HandleConnect(ENetProtocolHeader* /*header*/, ENetProtocol* command)
{ {
if (!m_allowsIncomingConnections)
return nullptr;
UInt32 channelCount = NetToHost(command->connect.channelCount); UInt32 channelCount = NetToHost(command->connect.channelCount);
if (channelCount < ENetProtocol_MinimumChannelCount || channelCount > ENetProtocol_MaximumChannelCount) if (channelCount < ENetProtocol_MinimumChannelCount || channelCount > ENetProtocol_MaximumChannelCount)