Changeset - r25245:5872175e0e0c
[Not reviewed]
master
0 6 0
Rubidium - 3 years ago 2021-04-18 08:23:41
rubidium@openttd.org
Codechange: encapsulate writing data from Packets into sockets/files/buffers to prevent packet state modifications outside of the Packet
6 files changed with 90 insertions and 26 deletions:
0 comments (0 inline, 0 general)
src/network/core/packet.h
Show inline comments
 
@@ -89,6 +89,58 @@ public:
 
	size_t RemainingBytesToTransfer() const;
 

	
 
	/**
 
	 * Transfer data from the packet to the given function. It starts reading at the
 
	 * position the last transfer stopped.
 
	 * See Packet::TransferIn for more information about transferring data to functions.
 
	 * @param transfer_function The function to pass the buffer as second parameter and the
 
	 *                          amount to write as third parameter. It returns the amount that
 
	 *                          was written or -1 upon errors.
 
	 * @param limit             The maximum amount of bytes to transfer.
 
	 * @param destination       The first parameter of the transfer function.
 
	 * @param args              The fourth and further parameters to the transfer function, if any.
 
	 * @return The return value of the transfer_function.
 
	 */
 
	template <
 
		typename A = size_t, ///< The type for the amount to be passed, so it can be cast to the right type.
 
		typename F,          ///< The type of the function.
 
		typename D,          ///< The type of the destination.
 
		typename ... Args>   ///< The types of the remaining arguments to the function.
 
	ssize_t TransferOutWithLimit(F transfer_function, size_t limit, D destination, Args&& ... args)
 
	{
 
		size_t amount = std::min(this->RemainingBytesToTransfer(), limit);
 
		if (amount == 0) return 0;
 

	
 
		assert(this->pos < this->buffer.size());
 
		assert(this->pos + amount <= this->buffer.size());
 
		/* Making buffer a char means casting a lot in the Recv/Send functions. */
 
		const char *output_buffer = reinterpret_cast<const char*>(this->buffer + this->pos);
 
		ssize_t bytes = transfer_function(destination, output_buffer, static_cast<A>(amount), std::forward<Args>(args)...);
 
		if (bytes > 0) this->pos += bytes;
 
		return bytes;
 
	}
 

	
 
	/**
 
	 * Transfer data from the packet to the given function. It starts reading at the
 
	 * position the last transfer stopped.
 
	 * See Packet::TransferIn for more information about transferring data to functions.
 
	 * @param transfer_function The function to pass the buffer as second parameter and the
 
	 *                          amount to write as third parameter. It returns the amount that
 
	 *                          was written or -1 upon errors.
 
	 * @param destination       The first parameter of the transfer function.
 
	 * @param args              The fourth and further parameters to the transfer function, if any.
 
	 * @tparam A    The type for the amount to be passed, so it can be cast to the right type.
 
	 * @tparam F    The type of the transfer_function.
 
	 * @tparam D    The type of the destination.
 
	 * @tparam Args The types of the remaining arguments to the function.
 
	 * @return The return value of the transfer_function.
 
	 */
 
	template <typename A = size_t, typename F, typename D, typename ... Args>
 
	ssize_t TransferOut(F transfer_function, D destination, Args&& ... args)
 
	{
 
		return TransferOutWithLimit<A>(transfer_function, std::numeric_limits<size_t>::max(), destination, std::forward<Args>(args)...);
 
	}
 

	
 
	/**
 
	 * Transfer data from the given function into the packet. It starts writing at the
 
	 * position the last transfer stopped.
 
	 *
src/network/core/tcp.cpp
Show inline comments
 
@@ -103,7 +103,7 @@ SendPacketsState NetworkTCPSocketHandler
 

	
 
	p = this->packet_queue;
 
	while (p != nullptr) {
 
		res = send(this->sock, (const char*)p->buffer + p->pos, p->size - p->pos, 0);
 
		res = p->TransferOut<int>(send, this->sock, 0);
 
		if (res == -1) {
 
			int err = GET_LAST_ERROR();
 
			if (err != EWOULDBLOCK) {
 
@@ -122,10 +122,8 @@ SendPacketsState NetworkTCPSocketHandler
 
			return SPS_CLOSED;
 
		}
 

	
 
		p->pos += res;
 

	
 
		/* Is this packet sent? */
 
		if (p->pos == p->size) {
 
		if (p->RemainingBytesToTransfer() == 0) {
 
			/* Go to the next packet */
 
			this->packet_queue = p->next;
 
			delete p;
src/network/core/tcp_listen.h
Show inline comments
 
@@ -63,7 +63,7 @@ public:
 

	
 
					DEBUG(net, 1, "[%s] Banned ip tried to join (%s), refused", Tsocket::GetName(), entry.c_str());
 

	
 
					if (send(s, (const char*)p.buffer, p.size, 0) < 0) {
 
					if (p.TransferOut<int>(send, s, 0) < 0) {
 
						DEBUG(net, 0, "send failed with error %d", GET_LAST_ERROR());
 
					}
 
					closesocket(s);
 
@@ -80,7 +80,7 @@ public:
 
				Packet p(Tfull_packet);
 
				p.PrepareToSend();
 

	
 
				if (send(s, (const char*)p.buffer, p.size, 0) < 0) {
 
				if (p.TransferOut<int>(send, s, 0) < 0) {
 
					DEBUG(net, 0, "send failed with error %d", GET_LAST_ERROR());
 
				}
 
				closesocket(s);
src/network/core/udp.cpp
Show inline comments
 
@@ -99,7 +99,7 @@ void NetworkUDPSocketHandler::SendPacket
 
		}
 

	
 
		/* Send the buffer */
 
		int res = sendto(s.second, (const char*)p->buffer, p->size, 0, (const struct sockaddr *)send.GetAddress(), send.GetAddressLength());
 
		ssize_t res = p->TransferOut<int>(sendto, s.second, 0, (const struct sockaddr *)send.GetAddress(), send.GetAddressLength());
 
		DEBUG(net, 7, "[udp] sendto(%s)", send.GetAddressAsString().c_str());
 

	
 
		/* Check for any errors, but ignore it otherwise */
src/network/network_client.cpp
Show inline comments
 
@@ -35,7 +35,6 @@
 

	
 
/* This file handles all the client-commands */
 

	
 

	
 
/** Read some packets, and when do use that data as initial load filter. */
 
struct PacketReader : LoadFilter {
 
	static const size_t CHUNK = 32 * 1024;  ///< 32 KiB chunks of memory.
 
@@ -60,34 +59,37 @@ struct PacketReader : LoadFilter {
 
	}
 

	
 
	/**
 
	 * Simple wrapper around fwrite to be able to pass it to Packet's TransferOut.
 
	 * @param destination The reader to add the data to.
 
	 * @param source      The buffer to read data from.
 
	 * @param amount      The number of bytes to copy.
 
	 * @return The number of bytes that were copied.
 
	 */
 
	static inline ssize_t TransferOutMemCopy(PacketReader *destination, const char *source, size_t amount)
 
	{
 
		memcpy(destination->buf, source, amount);
 
		destination->buf += amount;
 
		destination->written_bytes += amount;
 
		return amount;
 
	}
 

	
 
	/**
 
	 * Add a packet to this buffer.
 
	 * @param p The packet to add.
 
	 */
 
	void AddPacket(const Packet *p)
 
	void AddPacket(Packet *p)
 
	{
 
		assert(this->read_bytes == 0);
 

	
 
		size_t in_packet = p->size - p->pos;
 
		size_t to_write  = std::min<size_t>(this->bufe - this->buf, in_packet);
 
		const byte *pbuf = p->buffer + p->pos;
 

	
 
		this->written_bytes += in_packet;
 
		if (to_write != 0) {
 
			memcpy(this->buf, pbuf, to_write);
 
			this->buf += to_write;
 
		}
 
		p->TransferOutWithLimit(TransferOutMemCopy, this->bufe - this->buf, this);
 

	
 
		/* Did everything fit in the current chunk, then we're done. */
 
		if (to_write == in_packet) return;
 
		if (p->RemainingBytesToTransfer() == 0) return;
 

	
 
		/* Allocate a new chunk and add the remaining data. */
 
		pbuf += to_write;
 
		to_write   = in_packet - to_write;
 
		this->blocks.push_back(this->buf = CallocT<byte>(CHUNK));
 
		this->bufe = this->buf + CHUNK;
 

	
 
		memcpy(this->buf, pbuf, to_write);
 
		this->buf += to_write;
 
		p->TransferOutWithLimit(TransferOutMemCopy, this->bufe - this->buf, this);
 
	}
 

	
 
	size_t Read(byte *rbuf, size_t size) override
src/network/network_content.cpp
Show inline comments
 
@@ -459,6 +459,18 @@ static bool GunzipFile(const ContentInfo
 
#endif /* defined(WITH_ZLIB) */
 
}
 

	
 
/**
 
 * Simple wrapper around fwrite to be able to pass it to Packet's TransferOut.
 
 * @param file   The file to write data to.
 
 * @param buffer The buffer to write to the file.
 
 * @param amount The number of bytes to write.
 
 * @return The number of bytes that were written.
 
 */
 
static inline ssize_t TransferOutFWrite(FILE *file, const char *buffer, size_t amount)
 
{
 
	return fwrite(buffer, 1, amount, file);
 
}
 

	
 
bool ClientNetworkContentSocketHandler::Receive_SERVER_CONTENT(Packet *p)
 
{
 
	if (this->curFile == nullptr) {
 
@@ -476,8 +488,8 @@ bool ClientNetworkContentSocketHandler::
 
		}
 
	} else {
 
		/* We have a file opened, thus are downloading internal content */
 
		size_t toRead = (size_t)(p->size - p->pos);
 
		if (fwrite(p->buffer + p->pos, 1, toRead, this->curFile) != toRead) {
 
		size_t toRead = p->RemainingBytesToTransfer();
 
		if (toRead != 0 && (size_t)p->TransferOut(TransferOutFWrite, this->curFile) != toRead) {
 
			DeleteWindowById(WC_NETWORK_STATUS_WINDOW, WN_NETWORK_STATUS_WINDOW_CONTENT_DOWNLOAD);
 
			ShowErrorMessage(STR_CONTENT_ERROR_COULD_NOT_DOWNLOAD, STR_CONTENT_ERROR_COULD_NOT_DOWNLOAD_FILE_NOT_WRITABLE, WL_ERROR);
 
			this->Close();
0 comments (0 inline, 0 general)