# HG changeset patch # User Patric Stout # Date 2024-01-02 21:05:25 # Node ID 50bd98948184894c35687338986c96e1949f1f71 # Parent d93207e77be122075a9744cf8adfa1f6ad246fa7 Fix: race-conditions in GUI updates when downloading HTTP files (#11639) diff --git a/src/network/core/http.h b/src/network/core/http.h --- a/src/network/core/http.h +++ b/src/network/core/http.h @@ -30,7 +30,7 @@ struct HTTPCallback { * @param length the amount of received data, 0 when all data has been received. * @note When nullptr is sent the HTTP socket handler is closed/freed. */ - virtual void OnReceiveData(const char *data, size_t length) = 0; + virtual void OnReceiveData(std::unique_ptr data, size_t length) = 0; /** * Check if there is a request to cancel the transfer. diff --git a/src/network/core/http_curl.cpp b/src/network/core/http_curl.cpp --- a/src/network/core/http_curl.cpp +++ b/src/network/core/http_curl.cpp @@ -17,6 +17,7 @@ #include "../network_internal.h" #include "http.h" +#include "http_shared.h" #include #include @@ -44,6 +45,11 @@ static auto _certificate_directories = { }; #endif /* UNIX */ +static std::vector _http_callbacks; +static std::vector _new_http_callbacks; +static std::mutex _http_callback_mutex; +static std::mutex _new_http_callback_mutex; + /** Single HTTP request. */ class NetworkHTTPRequest { public: @@ -59,11 +65,19 @@ public: callback(callback), data(data) { + std::lock_guard lock(_new_http_callback_mutex); + _new_http_callbacks.push_back(&this->callback); } - const std::string uri; ///< URI to connect to. - HTTPCallback *callback; ///< Callback to send data back on. - const std::string data; ///< Data to send, if any. + ~NetworkHTTPRequest() + { + std::lock_guard lock(_http_callback_mutex); + _http_callbacks.erase(std::remove(_http_callbacks.begin(), _http_callbacks.end(), &this->callback), _http_callbacks.end()); + } + + const std::string uri; ///< URI to connect to. + HTTPThreadSafeCallback callback; ///< Callback to send data back on. + const std::string data; ///< Data to send, if any. }; static std::thread _http_thread; @@ -92,6 +106,20 @@ static std::string _http_ca_path = ""; /* static */ void NetworkHTTPSocketHandler::HTTPReceive() { + std::lock_guard lock(_http_callback_mutex); + + { + std::lock_guard lock_new(_new_http_callback_mutex); + if (!_new_http_callbacks.empty()) { + /* We delay adding new callbacks, as HandleQueue() below might add a new callback. */ + _http_callbacks.insert(_http_callbacks.end(), _new_http_callbacks.begin(), _new_http_callbacks.end()); + _new_http_callbacks.clear(); + } + } + + for (auto &callback : _http_callbacks) { + callback->HandleQueue(); + } } void HttpThread() @@ -163,11 +191,16 @@ void HttpThread() /* Setup our (C-style) callback function which we pipe back into the callback. */ curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, +[](char *ptr, size_t size, size_t nmemb, void *userdata) -> size_t { Debug(net, 4, "HTTP callback: {} bytes", size * nmemb); - HTTPCallback *callback = static_cast(userdata); - callback->OnReceiveData(ptr, size * nmemb); + HTTPThreadSafeCallback *callback = static_cast(userdata); + + /* Copy the buffer out of CURL. OnReceiveData() will free it when done. */ + std::unique_ptr buffer = std::make_unique(size * nmemb); + memcpy(buffer.get(), ptr, size * nmemb); + callback->OnReceiveData(std::move(buffer), size * nmemb); + return size * nmemb; }); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, request->callback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &request->callback); /* Create a callback from which we can cancel. Sadly, there is no other * thread-safe way to do this. If the connection went idle, it can take @@ -175,10 +208,10 @@ void HttpThread() * do about this. */ curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, +[](void *userdata, curl_off_t /*dltotal*/, curl_off_t /*dlnow*/, curl_off_t /*ultotal*/, curl_off_t /*ulnow*/) -> int { - const HTTPCallback *callback = static_cast(userdata); - return (callback->IsCancelled() || _http_thread_exit) ? 1 : 0; + const HTTPThreadSafeCallback *callback = static_cast(userdata); + return (callback->cancelled || _http_thread_exit) ? 1 : 0; }); - curl_easy_setopt(curl, CURLOPT_XFERINFODATA, request->callback); + curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &request->callback); /* Perform the request. */ CURLcode res = curl_easy_perform(curl); @@ -187,15 +220,18 @@ void HttpThread() if (res == CURLE_OK) { Debug(net, 1, "HTTP request succeeded"); - request->callback->OnReceiveData(nullptr, 0); + request->callback.OnReceiveData(nullptr, 0); } else { long status_code = 0; curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &status_code); /* No need to be verbose about rate limiting. */ - Debug(net, (request->callback->IsCancelled() || _http_thread_exit || status_code == HTTP_429_TOO_MANY_REQUESTS) ? 1 : 0, "HTTP request failed: status_code: {}, error: {}", status_code, curl_easy_strerror(res)); - request->callback->OnFailure(); + Debug(net, (request->callback.cancelled || _http_thread_exit || status_code == HTTP_429_TOO_MANY_REQUESTS) ? 1 : 0, "HTTP request failed: status_code: {}, error: {}", status_code, curl_easy_strerror(res)); + request->callback.OnFailure(); } + + /* Wait till the callback tells us all data is dequeued. */ + request->callback.WaitTillEmpty(); } curl_easy_cleanup(curl); diff --git a/src/network/core/http_shared.h b/src/network/core/http_shared.h new file mode 100644 --- /dev/null +++ b/src/network/core/http_shared.h @@ -0,0 +1,118 @@ +/* + * This file is part of OpenTTD. + * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2. + * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see . + */ + +/** + * @file http_shared.h Shared functions for implementations of HTTP requests. + */ + +#ifndef NETWORK_CORE_HTTP_SHARED_H +#define NETWORK_CORE_HTTP_SHARED_H + +#include "http.h" + +#include +#include +#include + +/** Converts a HTTPCallback to a Thread-Safe variant. */ +class HTTPThreadSafeCallback { +private: + /** Entries on the queue for later handling. */ + class Callback { + public: + Callback(std::unique_ptr data, size_t length) : data(std::move(data)), length(length), failure(false) {} + Callback() : data(nullptr), length(0), failure(true) {} + + std::unique_ptr data; + size_t length; + bool failure; + }; + +public: + /** + * Similar to HTTPCallback::OnFailure, but thread-safe. + */ + void OnFailure() + { + std::lock_guard lock(this->mutex); + this->queue.emplace_back(); + } + + /** + * Similar to HTTPCallback::OnReceiveData, but thread-safe. + */ + void OnReceiveData(std::unique_ptr data, size_t length) + { + std::lock_guard lock(this->mutex); + this->queue.emplace_back(std::move(data), length); + } + + /** + * Process everything on the queue. + * + * Should be called from the Game Thread. + */ + void HandleQueue() + { + this->cancelled = callback->IsCancelled(); + + std::lock_guard lock(this->mutex); + + for (auto &item : this->queue) { + if (item.failure) { + this->callback->OnFailure(); + } else { + this->callback->OnReceiveData(std::move(item.data), item.length); + } + } + + this->queue.clear(); + this->queue_cv.notify_all(); + } + + /** + * Wait till the queue is dequeued. + */ + void WaitTillEmpty() + { + std::unique_lock lock(this->mutex); + + while (!queue.empty()) { + this->queue_cv.wait(lock); + } + } + + /** + * Check if the queue is empty. + */ + bool IsQueueEmpty() + { + std::lock_guard lock(this->mutex); + return this->queue.empty(); + } + + HTTPThreadSafeCallback(HTTPCallback *callback) : callback(callback) {} + + ~HTTPThreadSafeCallback() + { + std::lock_guard lock(this->mutex); + + /* Clear the list and notify explicitly. */ + queue.clear(); + queue_cv.notify_all(); + } + + std::atomic cancelled = false; + +private: + HTTPCallback *callback; ///< The callback to send data back on. + std::mutex mutex; ///< Mutex to protect the queue. + std::vector queue; ///< Queue of data to send back. + std::condition_variable queue_cv; ///< Condition variable to wait for the queue to be empty. +}; + +#endif /* NETWORK_CORE_HTTP_SHARED_H */ diff --git a/src/network/core/http_winhttp.cpp b/src/network/core/http_winhttp.cpp --- a/src/network/core/http_winhttp.cpp +++ b/src/network/core/http_winhttp.cpp @@ -15,6 +15,7 @@ #include "../network_internal.h" #include "http.h" +#include "http_shared.h" #include #include @@ -26,9 +27,9 @@ static HINTERNET _winhttp_session = null /** Single HTTP request. */ class NetworkHTTPRequest { private: - const std::wstring uri; ///< URI to connect to. - HTTPCallback *callback; ///< Callback to send data back on. - const std::string data; ///< Data to send, if any. + const std::wstring uri; ///< URI to connect to. + HTTPThreadSafeCallback callback; ///< Callback to send data back on. + const std::string data; ///< Data to send, if any. HINTERNET connection = nullptr; ///< Current connection object. HINTERNET request = nullptr; ///< Current request object. @@ -49,6 +50,11 @@ static std::vector static std::vector _new_http_requests; static std::mutex _new_http_requests_mutex; +static std::vector _http_callbacks; +static std::vector _new_http_callbacks; +static std::mutex _http_callback_mutex; +static std::mutex _new_http_callback_mutex; + /** * Create a new HTTP request. * @@ -61,6 +67,8 @@ NetworkHTTPRequest::NetworkHTTPRequest(c callback(callback), data(data) { + std::lock_guard lock(_new_http_callback_mutex); + _new_http_callbacks.push_back(&this->callback); } static std::string GetLastErrorAsString() @@ -113,7 +121,7 @@ void NetworkHTTPRequest::WinHttpCallback if (this->depth++ > 5) { Debug(net, 0, "HTTP request failed: too many redirects"); this->finished = true; - this->callback->OnFailure(); + this->callback.OnFailure(); return; } break; @@ -136,7 +144,7 @@ void NetworkHTTPRequest::WinHttpCallback /* No need to be verbose about rate limiting. */ Debug(net, status_code == HTTP_429_TOO_MANY_REQUESTS ? 1 : 0, "HTTP request failed: status-code {}", status_code); this->finished = true; - this->callback->OnFailure(); + this->callback.OnFailure(); return; } @@ -150,17 +158,15 @@ void NetworkHTTPRequest::WinHttpCallback DWORD size = *(DWORD *)info; /* Next step: read the data in a temporary allocated buffer. - * The buffer will be free'd in the next step. */ - char *buffer = size == 0 ? nullptr : MallocT(size); + * The buffer will be free'd by OnReceiveData() in the next step. */ + char *buffer = size == 0 ? nullptr : new char[size]; WinHttpReadData(this->request, buffer, size, 0); } break; case WINHTTP_CALLBACK_STATUS_READ_COMPLETE: Debug(net, 4, "HTTP callback: {} bytes", length); - this->callback->OnReceiveData(static_cast(info), length); - /* Free the temporary buffer that was allocated in the previous step. */ - free(info); + this->callback.OnReceiveData(std::unique_ptr(static_cast(info)), length); if (length == 0) { /* Next step: no more data available: request is finished. */ @@ -177,13 +183,13 @@ void NetworkHTTPRequest::WinHttpCallback case WINHTTP_CALLBACK_STATUS_REQUEST_ERROR: Debug(net, 0, "HTTP request failed: {}", GetLastErrorAsString()); this->finished = true; - this->callback->OnFailure(); + this->callback.OnFailure(); break; default: Debug(net, 0, "HTTP request failed: unexepected callback code 0x{:x}", code); this->finished = true; - this->callback->OnFailure(); + this->callback.OnFailure(); return; } } @@ -227,7 +233,7 @@ void NetworkHTTPRequest::Connect() if (this->connection == nullptr) { Debug(net, 0, "HTTP request failed: {}", GetLastErrorAsString()); this->finished = true; - this->callback->OnFailure(); + this->callback.OnFailure(); return; } @@ -237,7 +243,7 @@ void NetworkHTTPRequest::Connect() Debug(net, 0, "HTTP request failed: {}", GetLastErrorAsString()); this->finished = true; - this->callback->OnFailure(); + this->callback.OnFailure(); return; } @@ -258,14 +264,14 @@ void NetworkHTTPRequest::Connect() */ bool NetworkHTTPRequest::Receive() { - if (this->callback->IsCancelled()) { + if (this->callback.cancelled && !this->finished) { Debug(net, 1, "HTTP request failed: cancelled by user"); this->finished = true; - this->callback->OnFailure(); - return true; + this->callback.OnFailure(); + /* Fall-through, as we are waiting for IsQueueEmpty() to happen. */ } - return this->finished; + return this->finished && this->callback.IsQueueEmpty(); } /** @@ -279,6 +285,9 @@ NetworkHTTPRequest::~NetworkHTTPRequest( WinHttpCloseHandle(this->request); WinHttpCloseHandle(this->connection); } + + std::lock_guard lock(_http_callback_mutex); + _http_callbacks.erase(std::remove(_http_callbacks.begin(), _http_callbacks.end(), &this->callback), _http_callbacks.end()); } /* static */ void NetworkHTTPSocketHandler::Connect(const std::string &uri, HTTPCallback *callback, const std::string data) @@ -292,6 +301,25 @@ NetworkHTTPRequest::~NetworkHTTPRequest( /* static */ void NetworkHTTPSocketHandler::HTTPReceive() { + /* Process all callbacks. */ + { + std::lock_guard lock(_http_callback_mutex); + + { + std::lock_guard lock(_new_http_callback_mutex); + if (!_new_http_callbacks.empty()) { + /* We delay adding new callbacks, as HandleQueue() below might add a new callback. */ + _http_callbacks.insert(_http_callbacks.end(), _new_http_callbacks.begin(), _new_http_callbacks.end()); + _new_http_callbacks.clear(); + } + } + + for (auto &callback : _http_callbacks) { + callback->HandleQueue(); + } + } + + /* Process all requests. */ { std::lock_guard lock(_new_http_requests_mutex); if (!_new_http_requests.empty()) { diff --git a/src/network/network_content.cpp b/src/network/network_content.cpp --- a/src/network/network_content.cpp +++ b/src/network/network_content.cpp @@ -602,19 +602,21 @@ void ClientNetworkContentSocketHandler:: } } -void ClientNetworkContentSocketHandler::OnReceiveData(const char *data, size_t length) +void ClientNetworkContentSocketHandler::OnReceiveData(std::unique_ptr data, size_t length) { - assert(data == nullptr || length != 0); + assert(data.get() == nullptr || length != 0); /* Ignore any latent data coming from a connection we closed. */ - if (this->http_response_index == -2) return; + if (this->http_response_index == -2) { + return; + } this->lastActivity = std::chrono::steady_clock::now(); if (this->http_response_index == -1) { if (data != nullptr) { /* Append the rest of the response. */ - this->http_response.insert(this->http_response.end(), data, data + length); + this->http_response.insert(this->http_response.end(), data.get(), data.get() + length); return; } else { /* Make sure the response is properly terminated. */ @@ -627,13 +629,14 @@ void ClientNetworkContentSocketHandler:: if (data != nullptr) { /* We have data, so write it to the file. */ - if (fwrite(data, 1, length, this->curFile) != length) { + if (fwrite(data.get(), 1, length, this->curFile) != length) { /* Writing failed somehow, let try via the old method. */ this->OnFailure(); } else { /* Just received the data. */ this->OnDownloadProgress(this->curInfo, (int)length); } + /* Nothing more to do now. */ return; } diff --git a/src/network/network_content.h b/src/network/network_content.h --- a/src/network/network_content.h +++ b/src/network/network_content.h @@ -95,7 +95,7 @@ protected: void OnDownloadComplete(ContentID cid) override; void OnFailure() override; - void OnReceiveData(const char *data, size_t length) override; + void OnReceiveData(std::unique_ptr data, size_t length) override; bool IsCancelled() const override; bool BeforeDownload(); diff --git a/src/network/network_survey.cpp b/src/network/network_survey.cpp --- a/src/network/network_survey.cpp +++ b/src/network/network_survey.cpp @@ -110,7 +110,7 @@ void NetworkSurveyHandler::OnFailure() this->loaded.notify_all(); } -void NetworkSurveyHandler::OnReceiveData(const char *data, size_t) +void NetworkSurveyHandler::OnReceiveData(std::unique_ptr data, size_t) { if (data == nullptr) { Debug(net, 1, "Survey: survey results sent"); diff --git a/src/network/network_survey.h b/src/network/network_survey.h --- a/src/network/network_survey.h +++ b/src/network/network_survey.h @@ -20,7 +20,7 @@ class NetworkSurveyHandler : public HTTPCallback { protected: void OnFailure() override; - void OnReceiveData(const char *data, size_t length) override; + void OnReceiveData(std::unique_ptr data, size_t length) override; bool IsCancelled() const override { return false; } public: