From 2cba0bed5be8a25e1635b7cde0a98285f64b5469 Mon Sep 17 00:00:00 2001 From: ismaelsadeeq Date: Thu, 22 Aug 2024 15:26:34 +0100 Subject: [PATCH] [wallet]: lock wallet context before adding/removing wallet settings - Also add test that verifies that no race condition will occur during the process. Co-authored-by: furszy --- src/wallet/test/wallet_tests.cpp | 39 ++++++++++++++++++++++++++++++++ src/wallet/wallet.cpp | 31 +++++++++++++------------ src/wallet/wallet.h | 6 +++-- 3 files changed, 59 insertions(+), 17 deletions(-) diff --git a/src/wallet/test/wallet_tests.cpp b/src/wallet/test/wallet_tests.cpp index 44ffddb1687016..01b98c28c75f97 100644 --- a/src/wallet/test/wallet_tests.cpp +++ b/src/wallet/test/wallet_tests.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -329,6 +330,44 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) } } +// This test verifies that wallet settings can be added and removed +// Concurrently, ensuring no race conditions occur during either process. +BOOST_FIXTURE_TEST_CASE(write_wallet_settings_concurrently, TestingSetup) +{ + WalletContext context; + context.chain = m_node.chain.get(); + const auto NUM_WALLETS{5}; + std::vector threads; + threads.reserve(NUM_WALLETS); + // Since we're counting the number of wallets, ensure we start without any. + BOOST_REQUIRE(context.chain->getRwSetting("wallet").isNull()); + + // Add NUM_WALLETS wallets concurrently, ensure we end up with NUM_WALLETS. + { + for (int i{0}; i < NUM_WALLETS; i++) { + threads.emplace_back([i, &context] { + Assert(AddWalletSetting(context, strprintf("wallet_%d", i))); + }); + } + for (auto& t : threads) t.join(); + auto wallets = context.chain->getRwSetting("wallet"); + BOOST_CHECK_EQUAL(wallets.getValues().size(), NUM_WALLETS); + } + threads.clear(); + + // Remove NUM_WALLETS wallets concurrently, ensure we end up with 0 wallets. + { + for (int i{0}; i < NUM_WALLETS; i++) { + threads.emplace_back([i, &context] { + Assert(RemoveWalletSetting(context, strprintf("wallet_%d", i))); + }); + } + for (auto& t : threads) t.join(); + auto wallets = context.chain->getRwSetting("wallet"); + BOOST_CHECK_EQUAL(wallets.getValues().size(), 0); + } +} + // Check that GetImmatureCredit() returns a newly calculated value instead of // the cached value after a MarkDirty() call. // diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 5584b43520aa68..b6e714decb05e6 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -91,38 +91,40 @@ using util::ToString; namespace wallet { -bool AddWalletSetting(interfaces::Chain& chain, const std::string& wallet_name) +bool AddWalletSetting(WalletContext& context, const std::string& wallet_name) { - common::SettingsValue setting_value = chain.getRwSetting("wallet"); + LOCK(context.wallets_mutex); + common::SettingsValue setting_value = context.chain->getRwSetting("wallet"); if (!setting_value.isArray()) setting_value.setArray(); for (const common::SettingsValue& value : setting_value.getValues()) { if (value.isStr() && value.get_str() == wallet_name) return true; } setting_value.push_back(wallet_name); - return chain.updateRwSetting("wallet", setting_value); + return context.chain->updateRwSetting("wallet", setting_value); } -bool RemoveWalletSetting(interfaces::Chain& chain, const std::string& wallet_name) +bool RemoveWalletSetting(WalletContext& context, const std::string& wallet_name) { - common::SettingsValue setting_value = chain.getRwSetting("wallet"); + LOCK(context.wallets_mutex); + common::SettingsValue setting_value = context.chain->getRwSetting("wallet"); if (!setting_value.isArray()) return true; common::SettingsValue new_value(common::SettingsValue::VARR); for (const common::SettingsValue& value : setting_value.getValues()) { if (!value.isStr() || value.get_str() != wallet_name) new_value.push_back(value); } if (new_value.size() == setting_value.size()) return true; - return chain.updateRwSetting("wallet", new_value); + return context.chain->updateRwSetting("wallet", new_value); } -static void UpdateWalletSetting(interfaces::Chain& chain, +static void UpdateWalletSetting(WalletContext& context, const std::string& wallet_name, std::optional load_on_startup, std::vector& warnings) { if (!load_on_startup) return; - if (load_on_startup.value() && !AddWalletSetting(chain, wallet_name)) { + if (load_on_startup.value() && !AddWalletSetting(context, wallet_name)) { warnings.emplace_back(Untranslated("Wallet load on startup setting could not be updated, so wallet may not be loaded next node startup.")); - } else if (!load_on_startup.value() && !RemoveWalletSetting(chain, wallet_name)) { + } else if (!load_on_startup.value() && !RemoveWalletSetting(context, wallet_name)) { warnings.emplace_back(Untranslated("Wallet load on startup setting could not be updated, so wallet may still be loaded next node startup.")); } } @@ -157,7 +159,6 @@ bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet { assert(wallet); - interfaces::Chain& chain = wallet->chain(); std::string name = wallet->GetName(); // Unregister with the validation interface which also drops shared pointers. @@ -172,7 +173,7 @@ bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet wallet->NotifyUnload(); // Write the wallet setting - UpdateWalletSetting(chain, name, load_on_start, warnings); + UpdateWalletSetting(context, name, load_on_start, warnings); return true; } @@ -293,7 +294,7 @@ std::shared_ptr LoadWalletInternal(WalletContext& context, const std::s wallet->postInitProcess(); // Write the wallet setting - UpdateWalletSetting(*context.chain, name, load_on_start, warnings); + UpdateWalletSetting(context, name, load_on_start, warnings); return wallet; } catch (const std::runtime_error& e) { @@ -474,7 +475,7 @@ std::shared_ptr CreateWallet(WalletContext& context, const std::string& wallet->postInitProcess(); // Write the wallet settings - UpdateWalletSetting(*context.chain, name, load_on_start, warnings); + UpdateWalletSetting(context, name, load_on_start, warnings); // Legacy wallets are being deprecated, warn if a newly created wallet is legacy if (!(wallet_creation_flags & WALLET_FLAG_DESCRIPTORS)) { @@ -4324,7 +4325,7 @@ bool DoMigration(CWallet& wallet, WalletContext& context, bilingual_str& error, } // Add the wallet to settings - UpdateWalletSetting(*context.chain, wallet_name, /*load_on_startup=*/true, warnings); + UpdateWalletSetting(context, wallet_name, /*load_on_startup=*/true, warnings); } if (data->solvable_descs.size() > 0) { wallet.WalletLogPrintf("Making a new watchonly wallet containing the unwatched solvable scripts\n"); @@ -4361,7 +4362,7 @@ bool DoMigration(CWallet& wallet, WalletContext& context, bilingual_str& error, } // Add the wallet to settings - UpdateWalletSetting(*context.chain, wallet_name, /*load_on_startup=*/true, warnings); + UpdateWalletSetting(context, wallet_name, /*load_on_startup=*/true, warnings); } } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 3ea1cf48b22bb3..97e70f636ccefb 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -19,6 +19,7 @@ #include