Skip to content

Commit

Permalink
Send client notifications from a separate workqueue (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
guhetier authored Feb 18, 2022
1 parent 41f095a commit 4ef8838
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 18 deletions.
45 changes: 35 additions & 10 deletions lib/OperationHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void OperationHandler::SendGuestNotification(std::variant<DisconnectNotif, Signa
}
}

void OperationHandler::NotifyConnectionToClient(EventSource source, const GUID& interfaceGuid, const DOT11_SSID& network, DOT11_AUTH_ALGORITHM authAlgo)
void OperationHandler::NotifyConnectionToClientSerialized(EventSource source, const GUID& interfaceGuid, const DOT11_SSID& network, DOT11_AUTH_ALGORITHM authAlgo)
{
Log::Info(
L"Notifying a connection to client. Source: %ws, Interface: %ws, Ssid: %ws, Auth Algo: %ws",
Expand All @@ -79,7 +79,7 @@ void OperationHandler::NotifyConnectionToClient(EventSource source, const GUID&
}
}

void OperationHandler::NotifyDisconnectionToClient(EventSource source, const GUID& interfaceGuid, const DOT11_SSID& network)
void OperationHandler::NotifyDisconnectionToClientSerialized(EventSource source, const GUID& interfaceGuid, const DOT11_SSID& network)
{
Log::Info(
L"Notifying a disconnection to client. Source: %ws, Interface: %ws, Ssid: %ws",
Expand All @@ -93,26 +93,44 @@ void OperationHandler::NotifyDisconnectionToClient(EventSource source, const GUI
}
}

void OperationHandler::NotifyConnectionToClient(EventSource source, const GUID& interfaceGuid, const DOT11_SSID& network, DOT11_AUTH_ALGORITHM authAlgo)
{
m_clientNotificationQueue.RunAndWait([this, source, interfaceGuid, network, authAlgo] {
NotifyConnectionToClientSerialized(source, interfaceGuid, network, authAlgo);
});
}

void OperationHandler::NotifyDisconnectionToClient(EventSource source, const GUID& interfaceGuid, const DOT11_SSID& network)
{
m_clientNotificationQueue.RunAndWait(
[this, source, interfaceGuid, network] { NotifyDisconnectionToClientSerialized(source, interfaceGuid, network); });
}

void OperationHandler::NotifyGuestConnectRequestProgress(GuestConnectStatus status)
{
Log::Info(L"Notifying guest directed connection progress to the client. Status: %ws", GuestConnectStatusToString(status));
m_clientNotificationQueue.RunAndWait([this, status] {
Log::Info(L"Notifying guest directed connection progress to the client. Status: %ws", GuestConnectStatusToString(status));

if (m_clientCallbacks.OnGuestConnectRequestProgress)
{
m_clientCallbacks.OnGuestConnectRequestProgress(status);
}
if (m_clientCallbacks.OnGuestConnectRequestProgress)
{
m_clientCallbacks.OnGuestConnectRequestProgress(status);
}
});
}

void OperationHandler::OnHostConnection(const GUID& interfaceGuid, const Ssid& ssid, DOT11_AUTH_ALGORITHM authAlgo)
{
// Always notify the client on a new host connection
NotifyConnectionToClient(EventSource::Host, interfaceGuid, ssid, authAlgo);
m_clientNotificationQueue.Run([this, interfaceGuid, ssid, authAlgo] {
NotifyConnectionToClientSerialized(EventSource::Host, interfaceGuid, ssid, authAlgo);
});
}

void OperationHandler::OnHostDisconnection(const GUID& interfaceGuid, const Ssid& ssid)
{
// Notify the client first
NotifyDisconnectionToClient(EventSource::Host, interfaceGuid, ssid);
m_clientNotificationQueue.Run(
[this, interfaceGuid, ssid] { NotifyDisconnectionToClientSerialized(EventSource::Host, interfaceGuid, ssid); });

// If this is a spontaneous disconnection from the host on the interface currently used by the guest,
// send a disconnect notification to the guest and mark it as disconnected
Expand All @@ -134,7 +152,8 @@ void OperationHandler::OnHostSignalQualityChange(const GUID& interfaceGuid, unsi
{
// Only forward notification for the currently connected interface to the guest
m_serializedRunner.Run([this, interfaceGuid, signalQuality] {
if (m_guestConnection && interfaceGuid == m_guestConnection->interfaceGuid) {
if (m_guestConnection && interfaceGuid == m_guestConnection->interfaceGuid)
{
Log::Trace(L"Send Signal quality change notification to the guest, Signal quality: %d", signalQuality);
SendGuestNotification(SignalQualityNotif{Wlansvc::LinkQualityToRssi(signalQuality)});
}
Expand Down Expand Up @@ -403,4 +422,10 @@ ScanResponse OperationHandler::HandleScanRequest(const ScanRequest& scanRequest)
return m_serializedRunner.RunAndWait([&] { return HandleScanRequestSerialized(scanRequest); });
}

void OperationHandler::DrainClientNotifications()
{
// Wait for a task doing nothing: this ensure all previous notification have been processed
m_clientNotificationQueue.RunAndWait([] { return; });
}

} // namespace ProxyWifi
8 changes: 8 additions & 0 deletions lib/OperationHandler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class OperationHandler: private INotificationHandler
void RegisterGuestNotificationCallback(GuestNotificationCallback notificationCallback);
void ClearGuestNotificationCallback();

/// @brief Wait all client notifications have been processed and return
/// Unit test helper
void DrainClientNotifications();

protected:

/// @brief Must be called by the interfaces when they connect to a network
Expand All @@ -77,6 +81,9 @@ class OperationHandler: private INotificationHandler
/// @brief Send a notification to the guest
void SendGuestNotification(std::variant<DisconnectNotif, SignalQualityNotif> notif);

void NotifyConnectionToClientSerialized(EventSource source, const GUID& interfaceGuid, const DOT11_SSID& network, DOT11_AUTH_ALGORITHM authAlgo);
void NotifyDisconnectionToClientSerialized(EventSource source, const GUID& interfaceGuid, const DOT11_SSID& network);

/// @brief Notify the lib user that the host/guest connected
void NotifyConnectionToClient(EventSource source, const GUID& interfaceGuid, const DOT11_SSID& network, DOT11_AUTH_ALGORITHM authAlgo);

Expand Down Expand Up @@ -124,6 +131,7 @@ class OperationHandler: private INotificationHandler
std::vector<std::unique_ptr<IWlanInterface>> m_wlanInterfaces;

SerializedWorkRunner m_serializedRunner;
SerializedWorkRunner m_clientNotificationQueue;
};

} // namespace ProxyWifi
62 changes: 54 additions & 8 deletions test/TestOpHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ TEST_CASE("Handle disconnect requests with multiple interfaces", "[wlansvcOpHand
}
}

TEST_CASE("Notify client on host connection and disconnection", "[wlansvcOpHandler]")
TEST_CASE("Notify client on host connection and disconnection", "[wlansvcOpHandler][clientNotification]")
{
auto fakeWlansvc =
std::make_shared<Mock::WlanSvcFake>(std::vector{Mock::c_intf1}, std::vector{Mock::c_wpa2PskNetwork, Mock::c_openNetwork});
Expand All @@ -443,6 +443,7 @@ TEST_CASE("Notify client on host connection and disconnection", "[wlansvcOpHandl
SECTION("Notifications on guest initiated operations")
{
auto connectResponse = opHandler->HandleConnectRequest(connectRequest);
opHandler->DrainClientNotifications();
CHECK(connectResponse->result_code == WI_EnumValue(WlanStatus::Success));
CHECK(fakeWlansvc->callCount.connect == 1);
CHECK(hostConnect == 1);
Expand All @@ -451,6 +452,7 @@ TEST_CASE("Notify client on host connection and disconnection", "[wlansvcOpHandl
CHECK(guestDisconnect == 0);

auto disconnectResponse = opHandler->HandleDisconnectRequest(disconnectRequest);
opHandler->DrainClientNotifications();
CHECK(fakeWlansvc->callCount.disconnect == 1);
CHECK(hostConnect == 1);
CHECK(guestConnect == 1);
Expand All @@ -462,13 +464,15 @@ TEST_CASE("Notify client on host connection and disconnection", "[wlansvcOpHandl
{
fakeWlansvc->ConnectHost(Mock::c_intf1, Mock::c_wpa2PskNetwork.bss.ssid);
fakeWlansvc->WaitForNotifComplete();
opHandler->DrainClientNotifications();
CHECK(hostConnect == 1);
CHECK(guestConnect == 0);
CHECK(hostDisconnect == 0);
CHECK(guestDisconnect == 0);

fakeWlansvc->DisconnectHost(Mock::c_intf1);
fakeWlansvc->WaitForNotifComplete();
opHandler->DrainClientNotifications();

CHECK(hostConnect == 1);
CHECK(guestConnect == 0);
Expand All @@ -480,12 +484,14 @@ TEST_CASE("Notify client on host connection and disconnection", "[wlansvcOpHandl
{
fakeWlansvc->ConnectHost(Mock::c_intf1, Mock::c_openNetwork.bss.ssid);
fakeWlansvc->WaitForNotifComplete();
opHandler->DrainClientNotifications();
CHECK(hostConnect == 1);
CHECK(guestConnect == 0);
CHECK(hostDisconnect == 0);
CHECK(guestDisconnect == 0);

auto connectResponse = opHandler->HandleConnectRequest(connectRequest);
opHandler->DrainClientNotifications();
CHECK(connectResponse->result_code == WI_EnumValue(WlanStatus::Success));
CHECK(fakeWlansvc->callCount.connect == 0);
CHECK(hostConnect == 1);
Expand All @@ -494,6 +500,7 @@ TEST_CASE("Notify client on host connection and disconnection", "[wlansvcOpHandl
CHECK(guestDisconnect == 0);

auto disconnectResponse = opHandler->HandleDisconnectRequest(disconnectRequest);
opHandler->DrainClientNotifications();
CHECK(fakeWlansvc->callCount.disconnect == 0);
CHECK(hostConnect == 1);
CHECK(guestConnect == 1);
Expand All @@ -502,6 +509,7 @@ TEST_CASE("Notify client on host connection and disconnection", "[wlansvcOpHandl

fakeWlansvc->DisconnectHost(Mock::c_intf1);
fakeWlansvc->WaitForNotifComplete();
opHandler->DrainClientNotifications();

CHECK(hostConnect == 1);
CHECK(guestConnect == 1);
Expand All @@ -510,7 +518,7 @@ TEST_CASE("Notify client on host connection and disconnection", "[wlansvcOpHandl
}
}

TEST_CASE("Provide the authentication algorithm on host connections", "[wlansvcOpHandler]")
TEST_CASE("Provide the authentication algorithm on host connections", "[wlansvcOpHandler][clientNotification]")
{
auto fakeWlansvc =
std::make_shared<Mock::WlanSvcFake>(std::vector{Mock::c_intf1}, std::vector{Mock::c_wpa2PskNetwork, Mock::c_openNetwork, Mock::c_enterpriseNetwork});
Expand All @@ -523,6 +531,7 @@ TEST_CASE("Provide the authentication algorithm on host connections", "[wlansvcO
{
fakeWlansvc->ConnectHost(Mock::c_intf1, Mock::c_wpa2PskNetwork.bss.ssid);
fakeWlansvc->WaitForNotifComplete();
opHandler->DrainClientNotifications();
REQUIRE(notifParams.size() == 1);
CHECK(notifParams[0].first == EventSource::Host);
CHECK(notifParams[0].second.authAlgo == DOT11_AUTH_ALGO_RSNA_PSK);
Expand All @@ -532,6 +541,7 @@ TEST_CASE("Provide the authentication algorithm on host connections", "[wlansvcO
{
fakeWlansvc->ConnectHost(Mock::c_intf1, Mock::c_openNetwork.bss.ssid);
fakeWlansvc->WaitForNotifComplete();
opHandler->DrainClientNotifications();
REQUIRE(notifParams.size() == 1);
CHECK(notifParams[0].first == EventSource::Host);
CHECK(notifParams[0].second.authAlgo == DOT11_AUTH_ALGO_80211_OPEN);
Expand All @@ -541,6 +551,7 @@ TEST_CASE("Provide the authentication algorithm on host connections", "[wlansvcO
{
fakeWlansvc->ConnectHost(Mock::c_intf1, Mock::c_enterpriseNetwork.bss.ssid);
fakeWlansvc->WaitForNotifComplete();
opHandler->DrainClientNotifications();
REQUIRE(notifParams.size() == 1);
CHECK(notifParams[0].first == EventSource::Host);
CHECK(notifParams[0].second.authAlgo == DOT11_AUTH_ALGO_RSNA_PSK);
Expand All @@ -558,7 +569,7 @@ TEST_CASE("Provide the authentication algorithm on host connections", "[wlansvcO
}
}

TEST_CASE("Notify client for guest directed connection progress", "[wlansvcOpHandler]")
TEST_CASE("Notify client for guest directed connection progress", "[wlansvcOpHandler][clientNotification]")
{
auto fakeWlansvc =
std::make_shared<Mock::WlanSvcFake>(std::vector{Mock::c_intf1}, std::vector{Mock::c_wpa2PskNetwork});
Expand Down Expand Up @@ -593,7 +604,7 @@ TEST_CASE("Notify client for guest directed connection progress", "[wlansvcOpHan
CHECK(succeeded == 1);
}

TEST_CASE("Notification for guest directed connection are in order", "[wlansvcOpHandler]")
TEST_CASE("Notification for guest directed connection are in order", "[wlansvcOpHandler][clientNotification]")
{
auto fakeWlansvc =
std::make_shared<Mock::WlanSvcFake>(std::vector{Mock::c_intf1}, std::vector{Mock::c_wpa2PskNetwork});
Expand All @@ -620,14 +631,13 @@ TEST_CASE("Notification for guest directed connection are in order", "[wlansvcOp
CHECK(notifs == std::vector{TestNotif::ConnectStarting, TestNotif::HostConnected, TestNotif::GuestConnected, TestNotif::ConnectSucceeded});
}

TEST_CASE("Notify client for initially connected networks", "[wlansvcOpHandler]")
TEST_CASE("Notify client for initially connected networks", "[wlansvcOpHandler][clientNotification]")
{
auto fakeWlansvc =
std::make_shared<Mock::WlanSvcFake>(std::vector{Mock::c_intf1}, std::vector{Mock::c_wpa2PskNetwork, Mock::c_openNetwork});

fakeWlansvc->ConnectHost(Mock::c_intf1, Mock::c_wpa2PskNetwork.bss.ssid);
fakeWlansvc->WaitForNotifComplete();

int hostConnect = 0;
int guestConnect = 0;
int hostDisconnect = 0;
Expand All @@ -638,14 +648,50 @@ TEST_CASE("Notify client for initially connected networks", "[wlansvcOpHandler]"
{},
{}},
fakeWlansvc);
opHandler->DrainClientNotifications();

CHECK(hostConnect == 1);
CHECK(guestConnect == 0);
CHECK(hostDisconnect == 0);
CHECK(guestDisconnect == 0);
}

TEST_CASE("Handle graciously non-expected wlansvc notifications", "[wlansvcOpHandler")
TEST_CASE("Initial notifications cannot deadlock a cient", "[wlansvcOpHandler][clientNotification]")
{
auto fakeWlansvc =
std::make_shared<Mock::WlanSvcFake>(std::vector{Mock::c_intf1}, std::vector{Mock::c_wpa2PskNetwork, Mock::c_openNetwork});

fakeWlansvc->ConnectHost(Mock::c_intf1, Mock::c_wpa2PskNetwork.bss.ssid);
fakeWlansvc->WaitForNotifComplete();

wil::srwlock clientLock;
bool noDeadlock = false;
auto opHandler = [&]() {
auto lock = clientLock.lock_exclusive();
return MakeUnitTestOperationHandler(
{[&](auto, auto) {
for (auto i = 0; i < 10; ++i)
{
auto lock = clientLock.try_lock_exclusive();
if (lock)
{
noDeadlock = true;
return;
}
std::this_thread::sleep_for(10ms);
}
},
{},
{},
{}},
fakeWlansvc);
}();

opHandler->DrainClientNotifications();
CHECK(noDeadlock);
}

TEST_CASE("Handle graciously non-expected wlansvc notifications", "[wlansvcOpHandler]")
{
auto fakeWlansvc =
std::make_shared<Mock::WlanSvcFake>(std::vector{Mock::c_intf1}, std::vector{Mock::c_wpa2PskNetwork, Mock::c_openNetwork});
Expand Down Expand Up @@ -801,7 +847,7 @@ TEST_CASE("Ignore notification from other interfaces", "[wlansvcOpHandler][multi
}
}

TEST_CASE("Notifications for fake networks use FakeInterfaceGuid")
TEST_CASE("Notifications for fake networks use FakeInterfaceGuid", "[wlansvcOpHandler][clientNotification]")
{
auto fakeWlansvc = std::make_shared<Mock::WlanSvcFake>();

Expand Down

0 comments on commit 4ef8838

Please sign in to comment.