Network/ENetHost: Dismiss external peer connection if listen address is loopback

This commit is contained in:
Jérôme Leclercq 2017-09-01 15:22:27 +02:00
parent 7074876d68
commit e37a7ad5fd
2 changed files with 16 additions and 9 deletions

View File

@ -55,8 +55,8 @@ namespace Nz
ENetPeer* Connect(const String& hostName, NetProtocol protocol = NetProtocol_Any, const String& service = "http", ResolveError* error = nullptr, std::size_t channelCount = 0, UInt32 data = 0);
inline bool Create(NetProtocol protocol, UInt16 port, std::size_t peerCount, std::size_t channelCount = 0);
bool Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount = 0);
bool Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth);
bool Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount = 0);
bool Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth);
void Destroy();
void Flush();
@ -152,6 +152,7 @@ namespace Nz
UInt32 m_totalReceivedPackets;
UInt64 m_totalSentData;
UInt64 m_totalReceivedData;
bool m_allowsIncomingConnections;
bool m_continueSending;
bool m_isSimulationEnabled;
bool m_recalculateBandwidthLimits;

View File

@ -143,14 +143,14 @@ namespace Nz
return Connect(hostnameAddress, channelCount, data);
}
bool ENetHost::Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount)
bool ENetHost::Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount)
{
return Create(address, peerCount, channelCount, 0, 0);
return Create(listenAddress, peerCount, channelCount, 0, 0);
}
bool ENetHost::Create(const IpAddress& address, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth)
bool ENetHost::Create(const IpAddress& listenAddress, std::size_t peerCount, std::size_t channelCount, UInt32 incomingBandwidth, UInt32 outgoingBandwidth)
{
NazaraAssert(address.IsValid(), "Invalid listening address");
NazaraAssert(listenAddress.IsValid(), "Invalid listening address");
if (peerCount > ENetConstants::ENetProtocol_MaximumPeerId)
{
@ -158,10 +158,11 @@ namespace Nz
return false;
}
if (!InitSocket(address))
if (!InitSocket(listenAddress))
return false;
m_address = address;
m_address = listenAddress;
m_allowsIncomingConnections = (listenAddress.IsValid() && !listenAddress.IsLoopback());
m_randomSeed = *reinterpret_cast<UInt32*>(this);
m_randomSeed += s_randomGenerator();
m_randomSeed = (m_randomSeed << 16) | (m_randomSeed >> 16);
@ -236,6 +237,8 @@ namespace Nz
break;
}
if (!m_allowsIncomingConnections && m_connectedPeers == 0)
switch (ReceiveIncomingCommands(event))
{
case 1:
@ -323,7 +326,7 @@ namespace Nz
m_socket.SetReceiveBufferSize(ENetConstants::ENetHost_ReceiveBufferSize);
m_socket.SetSendBufferSize(ENetConstants::ENetHost_SendBufferSize);
if (!address.IsLoopback())
if (address.IsValid() && !address.IsLoopback())
{
if (m_socket.Bind(address) != SocketState_Bound)
{
@ -407,6 +410,9 @@ namespace Nz
ENetPeer* ENetHost::HandleConnect(ENetProtocolHeader* /*header*/, ENetProtocol* command)
{
if (!m_allowsIncomingConnections)
return nullptr;
UInt32 channelCount = NetToHost(command->connect.channelCount);
if (channelCount < ENetProtocol_MinimumChannelCount || channelCount > ENetProtocol_MaximumChannelCount)