From 0360935709cfa3731dbed9d26992e5a6d529a2c3 Mon Sep 17 00:00:00 2001 From: koparasy Date: Mon, 27 Nov 2023 07:57:11 -0800 Subject: [PATCH] Synchronize RMQProducer thread with main thread using future/promises. Clean up buffers --- src/AMSlib/wf/basedb.hpp | 337 +++++++++++++++++++++++++----------- tests/AMSlib/CMakeLists.txt | 2 +- 2 files changed, 241 insertions(+), 98 deletions(-) diff --git a/src/AMSlib/wf/basedb.hpp b/src/AMSlib/wf/basedb.hpp index 00db8671..7f72107a 100644 --- a/src/AMSlib/wf/basedb.hpp +++ b/src/AMSlib/wf/basedb.hpp @@ -8,16 +8,21 @@ #ifndef __AMS_BASE_DB__ #define __AMS_BASE_DB__ + +#include #include #include #include #include +#include #include #include #include #include #include "AMS.h" +#include "debug.h" +#include "resource_manager.hpp" #include "wf/debug.h" #include "wf/device.hpp" #include "wf/resource_manager.hpp" @@ -51,6 +56,7 @@ using namespace sw::redis; #include #include #include +#include #include #include #include @@ -61,6 +67,7 @@ using namespace sw::redis; #include #include #include +#include #include #include #include @@ -706,7 +713,6 @@ class RedisDB : public BaseDB * * |__Header_(12B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| */ -template struct AMSMsgHeader { /** @brief Heaader size (bytes) */ uint8_t hsize; @@ -728,9 +734,13 @@ struct AMSMsgHeader { * @param[in] in_dim Inputs dimension * @param[in] out_dim Outputs dimension */ - AMSMsgHeader(size_t mpi_rank, size_t num_elem, size_t in_dim, size_t out_dim) + AMSMsgHeader(size_t mpi_rank, + size_t num_elem, + size_t in_dim, + size_t out_dim, + size_t type_size) : hsize(static_cast(AMSMsgHeader::size())), - dtype(static_cast(sizeof(TypeValue))), + dtype(static_cast(type_size)), mpi_rank(static_cast(mpi_rank)), num_elem(static_cast(num_elem)), in_dim(static_cast(in_dim)), @@ -742,7 +752,7 @@ struct AMSMsgHeader { * @brief Return the size of a header in the AMS protocol. * @return The size of a message header in AMS (in byte) */ - static size_t size() + static size_t constexpr size() { return ((sizeof(hsize) + sizeof(dtype) + sizeof(mpi_rank) + sizeof(num_elem) + sizeof(in_dim) + sizeof(out_dim) + @@ -780,12 +790,6 @@ struct AMSMsgHeader { std::memcpy(data_blob + current_offset, &(out_dim), sizeof(out_dim)); current_offset += sizeof(out_dim); - CFATAL(RMQHeader, - current_offset > AMSMsgHeader::size(), - "Offset is %d but header size is %d", - current_offset, - AMSMsgHeader::size()); - return AMSMsgHeader::size(); } }; @@ -794,7 +798,6 @@ struct AMSMsgHeader { /** * @brief Class representing a message for the AMSLib */ -template class AMSMessage { private: @@ -820,6 +823,7 @@ class AMSMessage * @param[in] inputs Inputs * @param[in] outputs Outputs */ + template AMSMessage(int id, size_t num_elements, const std::vector& inputs, @@ -834,20 +838,19 @@ class AMSMessage #ifdef __ENABLE_MPI__ MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); #endif - _total_size = AMSMsgHeader::size() + getDataSize(); + AMSMsgHeader header( + _rank, _num_elements, _input_dim, _output_dim, sizeof(TypeValue)); + + _total_size = AMSMsgHeader::size() + getTotalElements() * sizeof(TypeValue); _data = ams::ResourceManager::allocate(_total_size, AMSResourceType::HOST); - AMSMsgHeader header(_rank, - _num_elements, - _input_dim, - _output_dim); size_t current_offset = header.encode(_data); - DBG(RMQPublisherHandler, "Current Offset is %d", current_offset); current_offset += encode_data(reinterpret_cast(_data + current_offset), inputs, outputs); + DBG(AMSMessage, "Allocated message: %p", _data); } AMSMessage(const AMSMessage&) = delete; @@ -857,6 +860,7 @@ class AMSMessage AMSMessage& operator=(AMSMessage&& other) noexcept { + DBG(AMSMessage, "Move AMSMessage : %p -- %d", other._data, other._id); if (this != &other) { _id = other._id; _num_elements = other._num_elements; @@ -877,14 +881,15 @@ class AMSMessage * @param[in] outputs Outputs * @return The number of bytes in the message or 0 if error */ + template size_t encode_data(TypeValue* data_blob, const std::vector& inputs, const std::vector& outputs) { + size_t offset = 0; size_t x_dim = _input_dim + _output_dim; if (!data_blob) return 0; // Creating the body part of the messages - // TODO: slow method (one copy per element!), improve by reducing number of copies for (size_t i = 0; i < _num_elements; i++) { for (size_t j = 0; j < _input_dim; j++) { data_blob[i * x_dim + j] = inputs[j][i]; @@ -901,12 +906,12 @@ class AMSMessage } /** - * @brief Return the size of the data portion for that message + * @brief Return the total number of elements in this message * @return Size in bytes of the data portion */ - size_t getDataSize() + size_t getTotalElements() const { - return (_num_elements * (_input_dim + _output_dim)) * sizeof(TypeValue); + return (_num_elements * (_input_dim + _output_dim)); } /** @@ -921,6 +926,8 @@ class AMSMessage */ int id() const { return _id; } + int rank() const { return _rank; } + /** * @brief Return the size in bytes of the underlying binary blob * @return Byte size of data pointer @@ -929,8 +936,7 @@ class AMSMessage ~AMSMessage() { - if (_data) - ams::ResourceManager::deallocate(_data, AMSResourceType::HOST); + DBG(AMSMessage, "Destroying message with address %p %d", _data, _id) } }; // class AMSMessage @@ -949,7 +955,6 @@ typedef std::tuple /** * @brief Specific handler for RabbitMQ connections based on libevent. */ -template class RMQConsumerHandler : public AMQP::LibEventHandler { private: @@ -1108,7 +1113,8 @@ class RMQConsumerHandler : public AMQP::LibEventHandler _channel->ack(deliveryTag); std::string msg(message.body(), message.bodySize()); DBG(RMQConsumerHandler, - "message received [tag=%d] : '%s' of size %d B from '%s'/'%s'", + "message received [tag=%d] : '%s' of size %d B from " + "'%s'/'%s'", deliveryTag, msg.c_str(), message.bodySize(), @@ -1141,7 +1147,8 @@ class RMQConsumerHandler : public AMQP::LibEventHandler .onError([&](const char* message) { CFATAL(RMQConsumerHandler, false, - "[ERROR][rank=%d] Error while creating broker queue (%s): %s", + "[ERROR][rank=%d] Error while creating broker queue (%s): " + "%s", _rank, _queue.c_str(), message) @@ -1194,7 +1201,6 @@ class RMQConsumerHandler : public AMQP::LibEventHandler * @brief Class that manages a RabbitMQ broker and handles connection, event * loop and set up various handlers. */ -template class RMQConsumer { private: @@ -1209,7 +1215,7 @@ class RMQConsumer /** @brief The event loop for sender (usually the default one in libevent) */ std::shared_ptr _loop; /** @brief The handler which contains various callbacks for the sender */ - std::shared_ptr> _handler; + std::shared_ptr _handler; /** @brief Queue that contains all the messages received on receiver queue (messages can be popped in) */ std::vector _messages; @@ -1255,8 +1261,7 @@ class RMQConsumer [](struct event_base* event) { event_base_free(event); }); - _handler = - std::make_shared>(_loop, _cacert, _queue); + _handler = std::make_shared(_loop, _cacert, _queue); _connection = new AMQP::TcpConnection(_handler.get(), address); } @@ -1313,10 +1318,10 @@ class RMQConsumer /** * @brief Specific handler for RabbitMQ connections based on libevent. */ -template class RMQPublisherHandler : public AMQP::LibEventHandler { private: + enum ConnectionStatus { FAILED, CONNECTED, CLOSED }; /** @brief Path to TLS certificate */ std::string _cacert; /** @brief The MPI rank (0 if MPI is not used) */ @@ -1334,7 +1339,16 @@ class RMQPublisherHandler : public AMQP::LibEventHandler /** @brief Number of messages successfully acknowledged */ int _nb_msg_ack; + std::promise establish_connection; + std::future established; + + std::promise close_connection; + std::future closed; + public: + std::mutex ptr_mutex; + std::vector data_ptrs; + /** * @brief Constructor * @param[in] loop Event Loop @@ -1357,6 +1371,8 @@ class RMQPublisherHandler : public AMQP::LibEventHandler #ifdef __ENABLE_MPI__ MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); #endif + established = establish_connection.get_future(); + closed = close_connection.get_future(); } /** @@ -1364,43 +1380,75 @@ class RMQPublisherHandler : public AMQP::LibEventHandler * @param[in] data The data pointer * @param[in] data_size The number of bytes in the data pointer */ - void publish(const AMSMessage& msg) + void publish(AMSMessage&& msg) { if (_rchannel) { // publish a message via the reliable-channel _rchannel ->publish("", _queue, reinterpret_cast(msg.data()), msg.size()) - .onAck([&]() { + .onAck([_msg_ptr = msg.data(), + &_nb_msg_ack = _nb_msg_ack, + rank = msg.rank(), + id = msg.id(), + &ptr_mutex = ptr_mutex, + &data_ptrs = this->data_ptrs]() mutable { + const std::lock_guard lock(ptr_mutex); DBG(RMQPublisherHandler, - "[rank=%d] message #%d got acknowledged successfully by RMQ " + "[rank=%d] message #%d (Addr:%p) got acknowledged successfully " + "by " + "RMQ " "server", - _rank, - _nb_msg) + rank, + id, + _msg_ptr) _nb_msg_ack++; + data_ptrs.push_back(_msg_ptr); }) - .onNack([&]() { + .onNack([_msg_ptr = msg.data(), + &_nb_msg_ack = _nb_msg_ack, + rank = msg.rank(), + id = msg.id(), + &ptr_mutex = ptr_mutex, + &data_ptrs = this->data_ptrs]() mutable { + const std::lock_guard lock(ptr_mutex); WARNING(RMQPublisherHandler, "[rank=%d] message #%d received negative acknowledged by " "RMQ " "server", - _rank, - _nb_msg) + rank, + id) + data_ptrs.push_back(_msg_ptr); }) - .onLost([&]() { + .onLost([_msg_ptr = msg.data(), + &_nb_msg_ack = _nb_msg_ack, + rank = msg.rank(), + id = msg.id(), + &ptr_mutex = ptr_mutex, + &data_ptrs = this->data_ptrs]() mutable { + const std::lock_guard lock(ptr_mutex); CFATAL(RMQPublisherHandler, false, "[rank=%d] message #%d likely got lost by RMQ server", - _rank, - _nb_msg) + rank, + id) + data_ptrs.push_back(_msg_ptr); }) - .onError([&](const char* err_message) { - CFATAL(RMQPublisherHandler, - false, - "[rank=%d] message #%d did not get send: %s", - _rank, - _nb_msg, - err_message) - }); + .onError( + [_msg_ptr = msg.data(), + &_nb_msg_ack = _nb_msg_ack, + rank = msg.rank(), + id = msg.id(), + &ptr_mutex = ptr_mutex, + &data_ptrs = this->data_ptrs](const char* err_message) mutable { + const std::lock_guard lock(ptr_mutex); + CFATAL(RMQPublisherHandler, + false, + "[rank=%d] message #%d did not get send: %s", + rank, + id, + err_message) + data_ptrs.push_back(_msg_ptr); + }); } else { WARNING(RMQPublisherHandler, "[rank=%d] The reliable channel was not ready for message #%d.", @@ -1410,8 +1458,84 @@ class RMQPublisherHandler : public AMQP::LibEventHandler _nb_msg++; } + bool waitToEstablish(unsigned ms, int repeat = 1) + { + if (waitFuture(established, ms, repeat)) { + auto status = established.get(); + DBG(RMQPublisherHandler, "Connection Status: %d", status); + return status == CONNECTED; + } + return false; + } + + bool waitToClose(unsigned ms, int repeat = 1) + { + if (waitFuture(closed, ms, repeat)) { + return closed.get() == CLOSED; + } + return false; + } + ~RMQPublisherHandler() = default; + void release_message_buffers() + { + const std::lock_guard lock(ptr_mutex); + for (auto& dp : data_ptrs) { + DBG(RMQPublisherHandler, "deallocate address %p", dp) + ams::ResourceManager::deallocate(dp, AMSResourceType::HOST); + } + data_ptrs.erase(data_ptrs.begin(), data_ptrs.end()); + } + + unsigned unacknowledged() const { return _rchannel->unacknowledged(); } + + void flush() + { + uint32_t tries = 0; + while (auto unAck = _rchannel->unacknowledged()) { + DBG(RMQPublisherHandler, + "Waiting for %lu messages to be acknowledged", + unAck); + + if (++tries > 10) break; + std::this_thread::sleep_for(std::chrono::milliseconds(50 * tries)); + } + } + + // void purge() + // { + // std::promise purge_queue; + // std::future purged; + // purged = purge_queue.get_future(); + // + // _channel->purgeQueue(_queue) + // .onSuccess([&](uint32_t messageCount) { + // DBG(RMQPublisherHandler, + // "Sucessfuly purged queue with (%u) remaining messages", + // messageCount); + // purge_queue.set_value(true); + // }) + // .onError([&](const char* message) { + // DBG(RMQPublisherHandler, + // "Error '%s' when purging queue %s", + // message, + // _queue.c_str()); + // purge_queue.set_value(false); + // }) + // .onFinalize([&]() { + // DBG(RMQPublisherHandler, "Finalizing queue %s", _queue.c_str()) + // }); + // + // if (purged.get()) { + // DBG(RMQPublisherHandler, "Successfull destruction of RMQ queue"); + // return; + // } + // + // DBG(RMQPublisherHandler, "Non-successfull destruction of RMQ queue"); + // } + + private: /** * @brief Method that is called after a TCP connection has been set up, and @@ -1441,7 +1565,8 @@ class RMQPublisherHandler : public AMQP::LibEventHandler error += std::string(ERR_reason_error_string(err)); } error += "]"; - throw std::runtime_error(error); + establish_connection.set_value(FAILED); + return false; } else { DBG(RMQPublisherHandler, "Success logged with ca-chain %s", @@ -1510,6 +1635,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler _queue.c_str()) _rchannel = std::make_shared>(*_channel.get()); + establish_connection.set_value(CONNECTED); }) .onError([&](const char* message) { CFATAL(RMQPublisherHandler, @@ -1519,6 +1645,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler _rank, _queue.c_str(), message) + establish_connection.set_value(FAILED); }); } @@ -1546,10 +1673,10 @@ class RMQPublisherHandler : public AMQP::LibEventHandler virtual void onError(AMQP::TcpConnection* connection, const char* message) override { - DBG(RMQPublisherHandler, - "[rank=%d] fatal error when establishing TCP connection: %s\n", - _rank, - message) + FATAL(RMQPublisherHandler, + "[rank=%d] fatal error on TCP connection: %s\n", + _rank, + message) } /** @@ -1561,6 +1688,20 @@ class RMQPublisherHandler : public AMQP::LibEventHandler { // add your own implementation, like cleanup resources or exit the application DBG(RMQPublisherHandler, "[rank=%d] Connection is detached.\n", _rank) + close_connection.set_value(CLOSED); + } + + bool waitFuture(std::future& future, + unsigned ms, + int repeat) + { + std::chrono::milliseconds span(ms); + int iters = 0; + std::future_status status; + while ((status = future.wait_for(span)) == std::future_status::timeout && + (iters++ < repeat)) + std::future established; + return status == std::future_status::ready; } }; // class RMQPublisherHandler @@ -1569,7 +1710,6 @@ class RMQPublisherHandler : public AMQP::LibEventHandler * @brief Class that manages a RabbitMQ broker and handles connection, event * loop and set up various handlers. */ -template class RMQPublisher { private: @@ -1584,7 +1724,7 @@ class RMQPublisher /** @brief The event loop for sender (usually the default one in libevent) */ std::shared_ptr _loop; /** @brief The handler which contains various callbacks for the sender */ - std::shared_ptr> _handler; + std::shared_ptr _handler; public: RMQPublisher(const RMQPublisher&) = delete; @@ -1629,9 +1769,7 @@ class RMQPublisher event_base_free(event); }); - _handler = std::make_shared>(_loop, - _cacert, - _queue); + _handler = std::make_shared(_loop, _cacert, _queue); _connection = new AMQP::TcpConnection(_handler.get(), address); } @@ -1645,45 +1783,37 @@ class RMQPublisher * @brief Wait that the connection is ready (blocking call) * @return True if the publisher is ready to publish */ - void wait_ready(int ms = 500, int timeout_sec = 30) + bool waitToEstablish(unsigned ms, int repeat = 1) { - // We wait for the connection to be ready - int total_time = 0; - while (!ready_publish()) { - std::this_thread::sleep_for(std::chrono::milliseconds(ms)); - DBG(RMQPublisher, - "[rank=%d] Waiting for connection to be ready...", - _rank) - total_time += ms; - if (total_time > timeout_sec * 1000) { - DBG(RMQPublisher, "[rank=%d] Connection timeout", _rank) - break; - // TODO: if connection is not working -> revert to classic file DB. - } - } + return _handler->waitToEstablish(ms, repeat); } + unsigned unacknowledged() const { return _handler->unacknowledged(); } + + /** * @brief Start the underlying I/O loop (blocking call) */ - void start() - { - event_base_dispatch(_loop.get()); - // We wait for the connection to be ready - wait_ready(); - } + void start() { event_base_dispatch(_loop.get()); } /** * @brief Stop the underlying I/O loop */ void stop() { event_base_loopexit(_loop.get(), NULL); } - void publish(const AMSMessage& message) + void release_messages() { _handler->release_message_buffers(); } + + void publish(AMSMessage&& message) { _handler->publish(std::move(message)); } + + bool close(unsigned ms, int repeat = 1) { - _handler->publish(message); + _handler->flush(); + _connection->close(false); + return _handler->waitToClose(ms, repeat); } - ~RMQPublisher() { delete _connection; } + ~RMQPublisher() {} + }; // class RMQPublisher /** @@ -1757,11 +1887,11 @@ class RabbitMQDB final : public BaseDB /** @brief Represent the ID of the last message sent */ int _msg_tag; /** @brief Publisher sending messages to RMQ server */ - std::shared_ptr> _publisher; + std::shared_ptr _publisher; /** @brief Thread in charge of the publisher */ std::thread _publisher_thread; /** @brief Consumer listening to RMQ and consuming messages */ - std::shared_ptr> _consumer; + std::shared_ptr _consumer; /** @brief Thread in charge of the consumer */ std::thread _consumer_thread; @@ -1865,15 +1995,21 @@ class RabbitMQDB final : public BaseDB is_secure); std::string cacert = rmq_config["rabbitmq-cert"]; - _publisher = std::make_shared>(address, - cacert, - _queue_sender); - _consumer = std::make_shared>(address, - cacert, - _queue_receiver); + _publisher = std::make_shared(address, cacert, _queue_sender); _publisher_thread = std::thread([&]() { _publisher->start(); }); - _consumer_thread = std::thread([&]() { _consumer->start(); }); + + bool status = _publisher->waitToEstablish(100, 10); + if (!status) { + _publisher->stop(); + _publisher_thread.join(); + FATAL(RabbitMQDB, "Could not establish connection"); + } + + //_consumer_thread = std::thread([&]() { _consumer->start(); }); + //_consumer = std::make_shared>(address, + // cacert, + // _queue_receiver); } /** @@ -1898,8 +2034,8 @@ class RabbitMQDB final : public BaseDB inputs.size(), outputs.size()) - auto msg = AMSMessage(_msg_tag, num_elements, inputs, outputs); - _publisher->publish(msg); + _publisher->release_messages(); + _publisher->publish(AMSMessage(_msg_tag, num_elements, inputs, outputs)); _msg_tag++; } @@ -1916,10 +2052,17 @@ class RabbitMQDB final : public BaseDB ~RabbitMQDB() { + + bool status = _publisher->close(100, 10); + CWARNING(RabbitMQDB, !status, "Could not gracefully close TCP connection") + DBG(RabbitMQDB, + "Number of unacknowledged messages are %d", + _publisher->unacknowledged()) _publisher->stop(); - _consumer->stop(); + //_publisher->release_messages(); + //_consumer->stop(); _publisher_thread.join(); - _consumer_thread.join(); + //_consumer_thread.join(); } }; // class RabbitMQDB diff --git a/tests/AMSlib/CMakeLists.txt b/tests/AMSlib/CMakeLists.txt index bae8ef6d..83318b1f 100644 --- a/tests/AMSlib/CMakeLists.txt +++ b/tests/AMSlib/CMakeLists.txt @@ -7,7 +7,7 @@ function (BUILD_TEST exe source) add_executable(${exe} ${source}) target_include_directories(${exe} PRIVATE "${PROJECT_SOURCE_DIR}/src/AMSlib/" umpire ${caliper_INCLUDE_DIR} ${MPI_INCLUDE_PATH}) target_link_directories(${exe} PRIVATE ${AMS_APP_LIB_DIRS}) - target_link_libraries(${exe} PRIVATE AMS umpire MPI::MPI_CXX) + target_link_libraries(${exe} PRIVATE AMS ${AMS_APP_LIBRARIES}) target_compile_definitions(${exe} PRIVATE ${AMS_APP_DEFINES}) if (WITH_CUDA)