diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp
--- a/src/network/core/address.cpp
+++ b/src/network/core/address.cpp
@@ -87,7 +87,7 @@ const sockaddr_storage *NetworkAddress::
 		 * bothered to implement the specifications and allow '0' as value
 		 * that means "don't care whether it is SOCK_STREAM or SOCK_DGRAM".
 		 */
-		this->Resolve(this->address.ss_family, SOCK_STREAM, AI_ADDRCONFIG, ResolveLoopProc);
+		this->Resolve(this->address.ss_family, SOCK_STREAM, AI_ADDRCONFIG, NULL, ResolveLoopProc);
 	}
 	return &this->address;
 }
@@ -146,7 +146,7 @@ bool NetworkAddress::IsInNetmask(char *n
 	return true;
 }
 
-SOCKET NetworkAddress::Resolve(int family, int socktype, int flags, LoopProc func)
+SOCKET NetworkAddress::Resolve(int family, int socktype, int flags, SocketList *sockets, LoopProc func)
 {
 	struct addrinfo *ai;
 	struct addrinfo hints;
@@ -159,6 +159,9 @@ SOCKET NetworkAddress::Resolve(int famil
 	char port_name[6];
 	seprintf(port_name, lastof(port_name), "%u", this->GetPort());
 
+	/* Setting both hostname to NULL and port to 0 is not allowed.
+	 * As port 0 means bind to any port, the other must mean that
+	 * we want to bind to 'all' IPs. */
 	if (this->address_length == 0 && StrEmpty(this->hostname)) {
 		strecpy(this->hostname, this->address.ss_family == AF_INET ? "0.0.0.0" : "::", lastof(this->hostname));
 	}
@@ -173,11 +176,28 @@ SOCKET NetworkAddress::Resolve(int famil
 	for (struct addrinfo *runp = ai; runp != NULL; runp = runp->ai_next) {
+		/* getaddrinfo() may return the same address more than once. When
+		 * collecting sockets, skip addresses already in the list: a second
+		 * socket for them would at best fail to bind and at worst replace,
+		 * and thereby leak, the socket already stored for that address. */
+		if (sockets != NULL) {
+			NetworkAddress addr(runp->ai_addr, runp->ai_addrlen);
+			if (sockets->Contains(addr)) continue;
+		}
+
 		sock = func(runp);
 		if (sock == INVALID_SOCKET) continue;
 
-		this->address_length = runp->ai_addrlen;
-		assert(sizeof(this->address) >= runp->ai_addrlen);
-		memcpy(&this->address, runp->ai_addr, runp->ai_addrlen);
-		break;
+		if (sockets == NULL) {
+			/* Only the first working address is wanted; remember it. */
+			this->address_length = runp->ai_addrlen;
+			assert(sizeof(this->address) >= runp->ai_addrlen);
+			memcpy(&this->address, runp->ai_addr, runp->ai_addrlen);
+			break;
+		}
+
+		/* Store the socket and keep going to collect all addresses. */
+		NetworkAddress addr(runp->ai_addr, runp->ai_addrlen);
+		(*sockets)[addr] = sock;
+		sock = INVALID_SOCKET;
 	}
 	freeaddrinfo (ai);
 
@@ -215,7 +235,7 @@ SOCKET NetworkAddress::Connect()
 {
 	DEBUG(net, 1, "Connecting to %s", this->GetAddressAsString());
 
-	return this->Resolve(0, SOCK_STREAM, AI_ADDRCONFIG, ConnectLoopProc);
+	return this->Resolve(0, SOCK_STREAM, AI_ADDRCONFIG, NULL, ConnectLoopProc);
 }
 
 /**
@@ -231,7 +251,9 @@ static SOCKET ListenLoopProc(addrinfo *r
 		return INVALID_SOCKET;
 	}
 
-	if (!SetNoDelay(sock)) DEBUG(net, 1, "Setting TCP_NODELAY failed");
+	if (runp->ai_socktype == SOCK_STREAM && !SetNoDelay(sock)) {
+		DEBUG(net, 1, "Setting TCP_NODELAY failed");
+	}
 
 	int on = 1;
 	/* The (const char*) cast is needed for windows!! */
@@ -262,9 +284,9 @@ static SOCKET ListenLoopProc(addrinfo *r
 	return sock;
 }
 
-SOCKET NetworkAddress::Listen(int family, int socktype)
+SOCKET NetworkAddress::Listen(int family, int socktype, SocketList *sockets)
 {
-	return this->Resolve(family, socktype, AI_ADDRCONFIG | AI_PASSIVE, ListenLoopProc);
+	return this->Resolve(family, socktype, AI_ADDRCONFIG | AI_PASSIVE, sockets, ListenLoopProc);
 }
 
 #endif /* ENABLE_NETWORK */
diff --git a/src/network/core/address.h b/src/network/core/address.h
--- a/src/network/core/address.h
+++ b/src/network/core/address.h
@@ -10,10 +10,11 @@
 #include "os_abstraction.h"
 #include "config.h"
 #include "../../string_func.h"
-#include "../../core/smallvec_type.hpp"
+#include "../../core/smallmap_type.hpp"
 
 class NetworkAddress;
 typedef SmallVector<NetworkAddress, 4> NetworkAddressList;
+typedef SmallMap<NetworkAddress, SOCKET, 4> SocketList;
 
 /**
  * Wrapper for (un)resolved network addresses; there's no reason to transform
@@ -38,10 +39,13 @@ private:
 	 * @param family the type of 'protocol' (IPv4, IPv6)
 	 * @param socktype the type of socket (TCP, UDP, etc)
 	 * @param flags the flags to send to getaddrinfo
+	 * @param sockets the list to add every created socket to, or NULL to stop
+	 *                at the first working address and store it in this object
 	 * @param func the inner working while looping over the address info
-	 * @return the resolved socket or INVALID_SOCKET.
+	 * @return the resolved socket or INVALID_SOCKET; always INVALID_SOCKET
+	 *         when sockets != NULL, as the sockets are then put in the list.
 	 */
-	SOCKET Resolve(int family, int socktype, int flags, LoopProc func);
+	SOCKET Resolve(int family, int socktype, int flags, SocketList *sockets, LoopProc func);
 public:
 	/**
 	 * Create a network address based on a resolved IP and port
@@ -217,9 +221,12 @@ public:
 	 * Make the given socket listen.
 	 * @param family the type of 'protocol' (IPv4, IPv6)
 	 * @param socktype the type of socket (TCP, UDP, etc)
-	 * @return the listening socket or INVALID_SOCKET.
+	 * @param sockets the list to add every listening socket to, or NULL to
+	 *                only create the first one that works
+	 * @return the listening socket or INVALID_SOCKET if sockets == NULL;
+	 *         always INVALID_SOCKET if sockets != NULL.
 	 */
-	SOCKET Listen(int family, int socktype);
+	SOCKET Listen(int family, int socktype, SocketList *sockets = NULL);
 };
 
 #endif /* ENABLE_NETWORK */