diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp --- a/src/network/core/packet.cpp +++ b/src/network/core/packet.cpp @@ -68,6 +68,16 @@ void Packet::PrepareToSend() this->pos = 0; // We start reading from here } +/** + * Is it safe to write to the packet, i.e. didn't we run over the buffer? + * @param bytes_to_write The amount of bytes we want to try to write. + * @return True iff the given amount of bytes can be written to the packet. + */ +bool Packet::CanWriteToPacket(size_t bytes_to_write) +{ + return this->size + bytes_to_write < SEND_MTU; +} + /* * The next couple of functions make sure we can send * uint8, uint16, uint32 and uint64 endian-safe @@ -95,7 +105,7 @@ void Packet::Send_bool(bool data) */ void Packet::Send_uint8(uint8 data) { - assert(this->size < SEND_MTU - sizeof(data)); + assert(this->CanWriteToPacket(sizeof(data))); this->buffer[this->size++] = data; } @@ -105,7 +115,7 @@ void Packet::Send_uint8(uint8 data) */ void Packet::Send_uint16(uint16 data) { - assert(this->size < SEND_MTU - sizeof(data)); + assert(this->CanWriteToPacket(sizeof(data))); this->buffer[this->size++] = GB(data, 0, 8); this->buffer[this->size++] = GB(data, 8, 8); } @@ -116,7 +126,7 @@ void Packet::Send_uint16(uint16 data) */ void Packet::Send_uint32(uint32 data) { - assert(this->size < SEND_MTU - sizeof(data)); + assert(this->CanWriteToPacket(sizeof(data))); this->buffer[this->size++] = GB(data, 0, 8); this->buffer[this->size++] = GB(data, 8, 8); this->buffer[this->size++] = GB(data, 16, 8); @@ -129,7 +139,7 @@ void Packet::Send_uint32(uint32 data) */ void Packet::Send_uint64(uint64 data) { - assert(this->size < SEND_MTU - sizeof(data)); + assert(this->CanWriteToPacket(sizeof(data))); this->buffer[this->size++] = GB(data, 0, 8); this->buffer[this->size++] = GB(data, 8, 8); this->buffer[this->size++] = GB(data, 16, 8); @@ -148,8 +158,8 @@ void Packet::Send_uint64(uint64 data) void Packet::Send_string(const char *data) { assert(data != nullptr); - /* The <= *is* valid due to the fact that we are comparing sizes and not the index. */ - assert(this->size + strlen(data) + 1 <= SEND_MTU); + /* Length of the string + 1 for the '\0' termination. */ + assert(this->CanWriteToPacket(strlen(data) + 1)); while ((this->buffer[this->size++] = *data++) != '\0') {} } @@ -162,18 +172,21 @@ void Packet::Send_string(const char *dat /** - * Is it safe to read from the packet, i.e. didn't we run over the buffer ? - * @param bytes_to_read The amount of bytes we want to try to read. + * Is it safe to read from the packet, i.e. didn't we run over the buffer? + * In case \c close_connection is true, the connection will be closed when one would + * overrun the buffer. When it is false, the connection remains untouched. + * @param bytes_to_read The amount of bytes we want to try to read. + * @param close_connection Whether to close the connection if one cannot read that amount. * @return True if that is safe, otherwise false. */ -bool Packet::CanReadFromPacket(uint bytes_to_read) +bool Packet::CanReadFromPacket(size_t bytes_to_read, bool close_connection) { /* Don't allow reading from a quit client/client who send bad data */ if (this->cs->HasClientQuit()) return false; /* Check if variable is within packet-size */ if (this->pos + bytes_to_read > this->size) { - this->cs->NetworkSocketHandler::CloseConnection(); + if (close_connection) this->cs->NetworkSocketHandler::CloseConnection(); return false; } @@ -235,7 +248,7 @@ uint8 Packet::Recv_uint8() { uint8 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = this->buffer[this->pos++]; return n; @@ -249,7 +262,7 @@ uint16 Packet::Recv_uint16() { uint16 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = (uint16)this->buffer[this->pos++]; n += (uint16)this->buffer[this->pos++] << 8; @@ -264,7 +277,7 @@ uint32 Packet::Recv_uint32() { uint32 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = (uint32)this->buffer[this->pos++]; n += (uint32)this->buffer[this->pos++] << 8; @@ -281,7 +294,7 @@ uint64 Packet::Recv_uint64() { uint64 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = (uint64)this->buffer[this->pos++]; n += (uint64)this->buffer[this->pos++] << 8;