Network/SocketPoller: Wait() now returns the number of active sockets, and optionally the last error

It will also ignore the EINTR error on Linux
This commit is contained in:
Jérôme Leclercq
2018-06-12 14:31:49 +02:00
parent 15f84dc712
commit 56b23a2f27
16 changed files with 120 additions and 94 deletions

View File

@@ -41,7 +41,7 @@ namespace Nz
else
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
}
return newClient;
@@ -58,7 +58,7 @@ namespace Nz
if (bind(handle, reinterpret_cast<const sockaddr*>(&nameBuffer), bufferLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return SocketState_NotConnected;
}
@@ -76,7 +76,7 @@ namespace Nz
SocketHandle handle = socket(TranslateNetProtocolToAF(protocol), TranslateSocketTypeToSock(type), 0);
if (handle == InvalidHandle && error != nullptr)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return handle;
}
@@ -129,7 +129,7 @@ namespace Nz
if (errorCode == EADDRNOTAVAIL)
*error = SocketError_ConnectionRefused; //< ConnectionRefused seems more legit than AddressNotAvailable in connect case
else
*error = TranslateErrnoToResolveError(errorCode);
*error = TranslateErrnoToSocketError(errorCode);
}
return SocketState_NotConnected;
@@ -162,7 +162,7 @@ namespace Nz
if (code)
{
if (error)
*error = TranslateErrnoToResolveError(code);
*error = TranslateErrnoToSocketError(code);
return SocketState_NotConnected;
}
@@ -177,7 +177,7 @@ namespace Nz
else
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return SocketState_NotConnected;
}
@@ -202,7 +202,7 @@ namespace Nz
if (code < 0)
return SocketError_Internal;
return TranslateErrnoToResolveError(code);
return TranslateErrnoToSocketError(code);
}
int SocketImpl::GetLastErrorCode()
@@ -218,7 +218,7 @@ namespace Nz
if (getsockopt(handle, SOL_SOCKET, SO_ERROR, &code, &codeLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return -1;
}
@@ -240,7 +240,7 @@ namespace Nz
if (bind(handle, reinterpret_cast<const sockaddr*>(&nameBuffer), bufferLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return SocketState_NotConnected;
}
@@ -248,7 +248,7 @@ namespace Nz
if (listen(handle, queueSize) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return SocketState_NotConnected;
}
@@ -267,7 +267,7 @@ namespace Nz
if (ioctl(handle, FIONREAD, &availableBytes) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return 0;
}
@@ -286,7 +286,7 @@ namespace Nz
if (getsockopt(handle, SOL_SOCKET, SO_BROADCAST, &code, &codeLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false;
}
@@ -305,7 +305,7 @@ namespace Nz
if (getsockopt(handle, SOL_SOCKET, SO_KEEPALIVE, &code, &codeLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false;
}
@@ -324,7 +324,7 @@ namespace Nz
if (getsockopt(handle, IPPROTO_IP, IP_MTU, &code, &codeLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return 0;
}
@@ -343,7 +343,7 @@ namespace Nz
if (getsockopt(handle, IPPROTO_TCP, TCP_NODELAY, &code, &codeLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false;
}
@@ -362,7 +362,7 @@ namespace Nz
if (getsockopt(handle, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char*>(&code), &codeLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return 0;
}
@@ -385,7 +385,7 @@ namespace Nz
if (getpeername(handle, reinterpret_cast<sockaddr*>(nameBuffer.data()), &bufferLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return IpAddress();
}
@@ -404,7 +404,7 @@ namespace Nz
if (getsockopt(handle, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char*>(&code), &codeLength) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return 0;
}
@@ -432,7 +432,7 @@ namespace Nz
if (errorCode == EINVAL)
*error = SocketError_NoError;
else
*error = TranslateErrnoToResolveError(errorCode);
*error = TranslateErrnoToSocketError(errorCode);
}
return IpAddress();
@@ -444,7 +444,7 @@ namespace Nz
return IpAddressImpl::FromSockAddr(reinterpret_cast<sockaddr*>(nameBuffer.data()));
}
int SocketImpl::Poll(PollSocket* fdarray, std::size_t nfds, int timeout, SocketError* error)
unsigned int SocketImpl::Poll(PollSocket* fdarray, std::size_t nfds, int timeout, SocketError* error)
{
NazaraAssert(fdarray && nfds > 0, "Invalid fdarray");
@@ -454,12 +454,12 @@ namespace Nz
if (result < 0)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return 0;
}
return result;
return static_cast<unsigned int>(result);
}
bool SocketImpl::Receive(SocketHandle handle, void* buffer, int length, int* read, SocketError* error)
@@ -486,7 +486,7 @@ namespace Nz
default:
{
if (error)
*error = TranslateErrnoToResolveError(errorCode);
*error = TranslateErrnoToSocketError(errorCode);
return false; //< Error
}
@@ -541,7 +541,7 @@ namespace Nz
default:
{
if (error)
*error = TranslateErrnoToResolveError(errorCode);
*error = TranslateErrnoToSocketError(errorCode);
return false; //< Error
}
@@ -618,7 +618,7 @@ namespace Nz
default:
{
if (error)
*error = TranslateErrnoToResolveError(errorCode);
*error = TranslateErrnoToSocketError(errorCode);
return false; //< Error
}
@@ -677,7 +677,7 @@ namespace Nz
default:
{
if (error)
*error = TranslateErrnoToResolveError(errorCode);
*error = TranslateErrnoToSocketError(errorCode);
return false; //< Error
}
@@ -730,7 +730,7 @@ namespace Nz
default:
{
if (error)
*error = TranslateErrnoToResolveError(errorCode);
*error = TranslateErrnoToSocketError(errorCode);
return false; //< Error
}
@@ -770,7 +770,7 @@ namespace Nz
default:
{
if (error)
*error = TranslateErrnoToResolveError(errorCode);
*error = TranslateErrnoToSocketError(errorCode);
return false; //< Error
}
@@ -794,7 +794,7 @@ namespace Nz
if (ioctl(handle, FIONBIO, &block) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false; //< Error
}
@@ -813,7 +813,7 @@ namespace Nz
if (setsockopt(handle, SOL_SOCKET, SO_BROADCAST, reinterpret_cast<const char*>(&option), sizeof(option)) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false; //< Error
}
@@ -832,7 +832,7 @@ namespace Nz
if (setsockopt(handle, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast<const char*>(&option), sizeof(option)) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false; //< Error
}
@@ -854,7 +854,7 @@ namespace Nz
if (setsockopt(handle, SOL_SOCKET, SO_KEEPALIVE, &keepAlive , sizeof(keepAlive)) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false; //< Error
}
@@ -862,7 +862,7 @@ namespace Nz
if (setsockopt(handle, IPPROTO_TCP, TCP_KEEPIDLE, &keepIdle, sizeof(keepIdle)) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false; //< Error
}
@@ -870,7 +870,7 @@ namespace Nz
if (setsockopt(handle, IPPROTO_TCP, TCP_KEEPINTVL, &keepInterval, sizeof(keepInterval)) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false; //< Error
}
@@ -889,7 +889,7 @@ namespace Nz
if (setsockopt(handle, IPPROTO_TCP, TCP_NODELAY, &option, sizeof(option)) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false; //< Error
}
@@ -908,7 +908,7 @@ namespace Nz
if (setsockopt(handle, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<const char*>(&option), sizeof(option)) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false; //< Error
}
@@ -927,7 +927,7 @@ namespace Nz
if (setsockopt(handle, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<const char*>(&option), sizeof(option)) == SOCKET_ERROR)
{
if (error)
*error = TranslateErrnoToResolveError(GetLastErrorCode());
*error = TranslateErrnoToSocketError(GetLastErrorCode());
return false; //< Error
}
@@ -938,7 +938,7 @@ namespace Nz
return true;
}
SocketError SocketImpl::TranslateErrnoToResolveError(int error)
SocketError SocketImpl::TranslateErrnoToSocketError(int error)
{
switch (error)
{
@@ -974,6 +974,9 @@ namespace Nz
case ECONNREFUSED:
return SocketError_ConnectionRefused;
case EINTR:
return SocketError_Interrupted;
case EMSGSIZE:
return SocketError_DatagramSize;

View File

@@ -60,7 +60,7 @@ namespace Nz
static std::size_t QueryReceiveBufferSize(SocketHandle handle, SocketError* error = nullptr);
static std::size_t QuerySendBufferSize(SocketHandle handle, SocketError* error = nullptr);
static int Poll(PollSocket* fdarray, std::size_t nfds, int timeout, SocketError* error);
static unsigned int Poll(PollSocket* fdarray, std::size_t nfds, int timeout, SocketError* error);
static bool Receive(SocketHandle handle, void* buffer, int length, int* read, SocketError* error);
static bool ReceiveFrom(SocketHandle handle, void* buffer, int length, IpAddress* from, int* read, SocketError* error);
@@ -78,7 +78,7 @@ namespace Nz
static bool SetReceiveBufferSize(SocketHandle handle, std::size_t size, SocketError* error = nullptr);
static bool SetSendBufferSize(SocketHandle handle, std::size_t size, SocketError* error = nullptr);
static SocketError TranslateErrnoToResolveError(int error);
static SocketError TranslateErrnoToSocketError(int error);
static int TranslateNetProtocolToAF(NetProtocol protocol);
static int TranslateSocketTypeToSock(SocketType type);

View File

@@ -76,9 +76,9 @@ namespace Nz
m_readyToWriteSockets.erase(socket);
}
int SocketPollerImpl::Wait(int msTimeout, SocketError* error)
unsigned int SocketPollerImpl::Wait(int msTimeout, SocketError* error)
{
int activeSockets;
unsigned int activeSockets;
// Reset status of sockets
activeSockets = SocketImpl::Poll(m_sockets.data(), m_sockets.size(), static_cast<int>(msTimeout), error);
@@ -87,7 +87,7 @@ namespace Nz
m_readyToWriteSockets.clear();
if (activeSockets > 0U)
{
int socketRemaining = activeSockets;
unsigned int socketRemaining = activeSockets;
for (PollSocket& entry : m_sockets)
{
if (!entry.revents)
@@ -103,7 +103,7 @@ namespace Nz
}
else
{
NazaraWarning("Socket " + String::Number(entry.fd) + " was returned by WSAPoll without POLLRDNORM nor POLLWRNORM events (events: 0x" + String::Number(entry.revents, 16) + ')');
NazaraWarning("Socket " + String::Number(entry.fd) + " was returned by poll without POLLRDNORM nor POLLWRNORM events (events: 0x" + String::Number(entry.revents, 16) + ')');
activeSockets--;
}

View File

@@ -23,13 +23,14 @@ namespace Nz
void Clear();
bool IsReady(SocketHandle socket) const;
bool IsReadyToRead(SocketHandle socket) const;
bool IsReadyToWrite(SocketHandle socket) const;
bool IsRegistered(SocketHandle socket) const;
bool RegisterSocket(SocketHandle socket, SocketPollEventFlags eventFlags);
void UnregisterSocket(SocketHandle socket);
int Wait(int msTimeout, SocketError* error);
unsigned int Wait(int msTimeout, SocketError* error);
private:
std::unordered_set<SocketHandle> m_readyToReadSockets;