diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index caf8a83d50d47a..776f2d1a0a7866 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -546,7 +546,7 @@ pub struct Channel { } impl Channel { - fn from_model(value: channel::Model) -> Self { + pub fn from_model(value: channel::Model) -> Self { Channel { id: value.id, visibility: value.visibility, @@ -604,16 +604,14 @@ pub struct RejoinedChannelBuffer { #[derive(Clone)] pub struct JoinRoom { pub room: proto::Room, - pub channel_id: Option, - pub channel_members: Vec, + pub channel: Option, } pub struct RejoinedRoom { pub room: proto::Room, pub rejoined_projects: Vec, pub reshared_projects: Vec, - pub channel_id: Option, - pub channel_members: Vec, + pub channel: Option, } pub struct ResharedProject { @@ -649,8 +647,7 @@ pub struct RejoinedWorktree { pub struct LeftRoom { pub room: proto::Room, - pub channel_id: Option, - pub channel_members: Vec, + pub channel: Option, pub left_projects: HashMap, pub canceled_calls_to_user_ids: Vec, pub deleted: bool, @@ -658,8 +655,7 @@ pub struct LeftRoom { pub struct RefreshedRoom { pub room: proto::Room, - pub channel_id: Option, - pub channel_members: Vec, + pub channel: Option, pub stale_participant_user_ids: Vec, pub canceled_calls_to_user_ids: Vec, } diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index d552f646a00f38..f465d3812a64e3 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -91,7 +91,9 @@ id_type!(NotificationKindId); id_type!(HostedProjectId); /// ChannelRole gives you permissions for both channels and calls. -#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)] +#[derive( + Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize, +)] #[sea_orm(rs_type = "String", db_type = "String(None)")] pub enum ChannelRole { /// Admin can read/write and change permissions. diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index f64f5d2588772d..3f168e08544cac 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -45,11 +45,7 @@ impl Database { name: &str, parent_channel_id: Option, admin_id: UserId, - ) -> Result<( - Channel, - Option, - Vec, - )> { + ) -> Result<(channel::Model, Option)> { let name = Self::sanitize_channel_name(name)?; self.transaction(move |tx| async move { let mut parent = None; @@ -90,12 +86,7 @@ impl Database { ); } - let channel_members = channel_member::Entity::find() - .filter(channel_member::Column::ChannelId.eq(channel.root_id())) - .all(&*tx) - .await?; - - Ok((Channel::from_model(channel), membership, channel_members)) + Ok((channel, membership)) }) .await } @@ -181,7 +172,7 @@ impl Database { channel_id: ChannelId, visibility: ChannelVisibility, admin_id: UserId, - ) -> Result<(Channel, Vec)> { + ) -> Result { self.transaction(move |tx| async move { let channel = self.get_channel_internal(channel_id, &tx).await?; self.check_user_is_channel_admin(&channel, admin_id, &tx) @@ -214,12 +205,7 @@ impl Database { model.visibility = ActiveValue::Set(visibility); let channel = model.update(&*tx).await?; - let channel_members = channel_member::Entity::find() - .filter(channel_member::Column::ChannelId.eq(channel.root_id())) - .all(&*tx) - .await?; - - Ok((Channel::from_model(channel), channel_members)) + Ok(channel) }) .await } @@ -245,21 +231,12 @@ impl Database { &self, channel_id: ChannelId, user_id: UserId, - ) -> Result<(Vec, Vec)> { + ) -> Result<(ChannelId, Vec)> { self.transaction(move |tx| async move { let channel = self.get_channel_internal(channel_id, &tx).await?; self.check_user_is_channel_admin(&channel, user_id, &tx) .await?; - let members_to_notify: Vec = channel_member::Entity::find() - .filter(channel_member::Column::ChannelId.eq(channel.root_id())) - .select_only() - .column(channel_member::Column::UserId) - .distinct() - .into_values::<_, QueryUserIds>() - .all(&*tx) - .await?; - let channels_to_remove = self .get_channel_descendants_excluding_self([&channel], &tx) .await? @@ -273,7 +250,7 @@ impl Database { .exec(&*tx) .await?; - Ok((channels_to_remove, members_to_notify)) + Ok((channel.root_id(), channels_to_remove)) }) .await } @@ -343,7 +320,7 @@ impl Database { channel_id: ChannelId, admin_id: UserId, new_name: &str, - ) -> Result<(Channel, Vec)> { + ) -> Result { self.transaction(move |tx| async move { let new_name = Self::sanitize_channel_name(new_name)?.to_string(); @@ -355,12 +332,7 @@ impl Database { model.name = ActiveValue::Set(new_name.clone()); let channel = model.update(&*tx).await?; - let channel_members = channel_member::Entity::find() - .filter(channel_member::Column::ChannelId.eq(channel.root_id())) - .all(&*tx) - .await?; - - Ok((Channel::from_model(channel), channel_members)) + Ok(channel) }) .await } @@ -984,7 +956,7 @@ impl Database { channel_id: ChannelId, new_parent_id: ChannelId, admin_id: UserId, - ) -> Result<(Vec, Vec)> { + ) -> Result<(ChannelId, Vec)> { self.transaction(|tx| async move { let channel = self.get_channel_internal(channel_id, &tx).await?; self.check_user_is_channel_admin(&channel, admin_id, &tx) @@ -1039,12 +1011,7 @@ impl Database { .map(|c| Channel::from_model(c)) .collect::>(); - let channel_members = channel_member::Entity::find() - .filter(channel_member::Column::ChannelId.eq(root_id)) - .all(&*tx) - .await?; - - Ok((channels, channel_members)) + Ok((root_id, channels)) }) .await } diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index dcb31266dfe1ac..62289cdeaa96ae 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -52,12 +52,7 @@ impl Database { ); let (channel, room) = self.get_channel_room(room_id, &tx).await?; - let channel_members; - if let Some(channel) = &channel { - channel_members = self.get_channel_participants(channel, &tx).await?; - } else { - channel_members = Vec::new(); - + if channel.is_none() { // Delete the room if it becomes empty. if room.participants.is_empty() { project::Entity::delete_many() @@ -70,8 +65,7 @@ impl Database { Ok(RefreshedRoom { room, - channel_id: channel.map(|channel| channel.id), - channel_members, + channel, stale_participant_user_ids, canceled_calls_to_user_ids, }) @@ -349,8 +343,7 @@ impl Database { let room = self.get_room(room_id, &tx).await?; Ok(JoinRoom { room, - channel_id: None, - channel_members: vec![], + channel: None, }) }) .await @@ -446,11 +439,9 @@ impl Database { let (channel, room) = self.get_channel_room(room_id, &tx).await?; let channel = channel.ok_or_else(|| anyhow!("no channel for room"))?; - let channel_members = self.get_channel_participants(&channel, tx).await?; Ok(JoinRoom { room, - channel_id: Some(channel.id), - channel_members, + channel: Some(channel), }) } @@ -736,16 +727,10 @@ impl Database { } let (channel, room) = self.get_channel_room(room_id, &tx).await?; - let channel_members = if let Some(channel) = &channel { - self.get_channel_participants(&channel, &tx).await? - } else { - Vec::new() - }; Ok(RejoinedRoom { room, - channel_id: channel.map(|channel| channel.id), - channel_members, + channel, rejoined_projects, reshared_projects, }) @@ -902,15 +887,9 @@ impl Database { false }; - let channel_members = if let Some(channel) = &channel { - self.get_channel_participants(channel, &tx).await? - } else { - Vec::new() - }; let left_room = LeftRoom { room, - channel_id: channel.map(|channel| channel.id), - channel_members, + channel, left_projects, canceled_calls_to_user_ids, deleted, diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 54be002c41c3bc..ad8f3467ecbae5 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -109,10 +109,9 @@ async fn test_channels(db: &Arc) { assert!(db.get_channel(crdb_id, a_id).await.is_err()); // Remove a channel tree - let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap(); + let (_, mut channel_ids) = db.delete_channel(rust_id, a_id).await.unwrap(); channel_ids.sort(); assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]); - assert_eq!(user_ids, &[a_id]); assert!(db.get_channel(rust_id, a_id).await.is_err()); assert!(db.get_channel(cargo_id, a_id).await.is_err()); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 6700ad8ccce787..a07b33921ecf64 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -3,10 +3,10 @@ mod connection_pool; use crate::{ auth::{self, Impersonator}, db::{ - self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database, - InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project, ProjectId, - RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId, User, - UserId, + self, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, + Database, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project, + ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId, + User, UserId, }, executor::Executor, AppState, Error, Result, @@ -351,14 +351,8 @@ impl Server { "refreshed room" ); room_updated(&refreshed_room.room, &peer); - if let Some(channel_id) = refreshed_room.channel_id { - channel_updated( - channel_id, - &refreshed_room.room, - &refreshed_room.channel_members, - &peer, - &pool.lock(), - ); + if let Some(channel) = refreshed_room.channel.as_ref() { + channel_updated(channel, &refreshed_room.room, &peer, &pool.lock()); } contacts_to_update .extend(refreshed_room.stale_participant_user_ids.iter().copied()); @@ -699,6 +693,9 @@ impl Server { { let mut pool = self.connection_pool.lock(); pool.add_connection(connection_id, user.id, user.admin, zed_version); + for membership in &channels_for_user.channel_memberships { + pool.subscribe_to_channel(user.id, membership.channel_id, membership.role) + } self.peer.send( connection_id, build_initial_contacts_update(contacts, &pool), @@ -1148,8 +1145,7 @@ async fn rejoin_room( session: Session, ) -> Result<()> { let room; - let channel_id; - let channel_members; + let channel; { let mut rejoined_room = session .db() @@ -1315,15 +1311,13 @@ async fn rejoin_room( let rejoined_room = rejoined_room.into_inner(); room = rejoined_room.room; - channel_id = rejoined_room.channel_id; - channel_members = rejoined_room.channel_members; + channel = rejoined_room.channel; } - if let Some(channel_id) = channel_id { + if let Some(channel) = channel { channel_updated( - channel_id, + &channel, &room, - &channel_members, &session.peer, &*session.connection_pool().await, ); @@ -2427,31 +2421,39 @@ async fn create_channel( let db = session.db().await; let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id)); - let (channel, owner, channel_members) = db + let (channel, membership) = db .create_channel(&request.name, parent_id, session.user_id) .await?; + let root_id = channel.root_id(); + let channel = Channel::from_model(channel); + response.send(proto::CreateChannelResponse { channel: Some(channel.to_proto()), parent_id: request.parent_id, })?; - let connection_pool = session.connection_pool().await; - if let Some(owner) = owner { + let mut connection_pool = session.connection_pool().await; + if let Some(membership) = membership { + connection_pool.subscribe_to_channel( + membership.user_id, + membership.channel_id, + membership.role, + ); let update = proto::UpdateUserChannels { channel_memberships: vec![proto::ChannelMembership { - channel_id: owner.channel_id.to_proto(), - role: owner.role.into(), + channel_id: membership.channel_id.to_proto(), + role: membership.role.into(), }], ..Default::default() }; - for connection_id in connection_pool.user_connection_ids(owner.user_id) { + for connection_id in connection_pool.user_connection_ids(membership.user_id) { session.peer.send(connection_id, update.clone())?; } } - for channel_member in channel_members { - if !channel_member.role.can_see_channel(channel.visibility) { + for (connection_id, role) in connection_pool.channel_connection_ids(root_id) { + if !role.can_see_channel(channel.visibility) { continue; } @@ -2459,9 +2461,7 @@ async fn create_channel( channels: vec![channel.to_proto()], ..Default::default() }; - for connection_id in connection_pool.user_connection_ids(channel_member.user_id) { - session.peer.send(connection_id, update.clone())?; - } + session.peer.send(connection_id, update.clone())?; } Ok(()) @@ -2476,7 +2476,7 @@ async fn delete_channel( let db = session.db().await; let channel_id = request.channel_id; - let (removed_channels, member_ids) = db + let (root_channel, removed_channels) = db .delete_channel(ChannelId::from_proto(channel_id), session.user_id) .await?; response.send(proto::Ack {})?; @@ -2488,10 +2488,8 @@ async fn delete_channel( .extend(removed_channels.into_iter().map(|id| id.to_proto())); let connection_pool = session.connection_pool().await; - for member_id in member_ids { - for connection_id in connection_pool.user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; - } + for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) { + session.peer.send(connection_id, update.clone())?; } Ok(()) @@ -2551,9 +2549,9 @@ async fn remove_channel_member( .remove_channel_member(channel_id, member_id, session.user_id) .await?; - let connection_pool = &session.connection_pool().await; + let mut connection_pool = session.connection_pool().await; notify_membership_updated( - &connection_pool, + &mut connection_pool, membership_update, member_id, &session.peer, @@ -2588,25 +2586,33 @@ async fn set_channel_visibility( let channel_id = ChannelId::from_proto(request.channel_id); let visibility = request.visibility().into(); - let (channel, channel_members) = db + let channel_model = db .set_channel_visibility(channel_id, visibility, session.user_id) .await?; + let root_id = channel_model.root_id(); + let channel = Channel::from_model(channel_model); - let connection_pool = session.connection_pool().await; - for member in channel_members { - let update = if member.role.can_see_channel(channel.visibility) { + let mut connection_pool = session.connection_pool().await; + for (user_id, role) in connection_pool + .channel_user_ids(root_id) + .collect::>() + .into_iter() + { + let update = if role.can_see_channel(channel.visibility) { + connection_pool.subscribe_to_channel(user_id, channel_id, role); proto::UpdateChannels { channels: vec![channel.to_proto()], ..Default::default() } } else { + connection_pool.unsubscribe_from_channel(&user_id, &channel_id); proto::UpdateChannels { delete_channels: vec![channel.id.to_proto()], ..Default::default() } }; - for connection_id in connection_pool.user_connection_ids(member.user_id) { + for connection_id in connection_pool.user_connection_ids(user_id) { session.peer.send(connection_id, update.clone())?; } } @@ -2635,9 +2641,9 @@ async fn set_channel_member_role( match result { db::SetMemberRoleResult::MembershipUpdated(membership_update) => { - let connection_pool = session.connection_pool().await; + let mut connection_pool = session.connection_pool().await; notify_membership_updated( - &connection_pool, + &mut connection_pool, membership_update, member_id, &session.peer, @@ -2671,24 +2677,23 @@ async fn rename_channel( ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); - let (channel, channel_members) = db + let channel_model = db .rename_channel(channel_id, session.user_id, &request.name) .await?; + let root_id = channel_model.root_id(); + let channel = Channel::from_model(channel_model); response.send(proto::RenameChannelResponse { channel: Some(channel.to_proto()), })?; let connection_pool = session.connection_pool().await; - for channel_member in channel_members { - if !channel_member.role.can_see_channel(channel.visibility) { - continue; - } - let update = proto::UpdateChannels { - channels: vec![channel.to_proto()], - ..Default::default() - }; - for connection_id in connection_pool.user_connection_ids(channel_member.user_id) { + let update = proto::UpdateChannels { + channels: vec![channel.to_proto()], + ..Default::default() + }; + for (connection_id, role) in connection_pool.channel_connection_ids(root_id) { + if role.can_see_channel(channel.visibility) { session.peer.send(connection_id, update.clone())?; } } @@ -2705,18 +2710,18 @@ async fn move_channel( let channel_id = ChannelId::from_proto(request.channel_id); let to = ChannelId::from_proto(request.to); - let (channels, channel_members) = session + let (root_id, channels) = session .db() .await .move_channel(channel_id, to, session.user_id) .await?; let connection_pool = session.connection_pool().await; - for member in channel_members { + for (connection_id, role) in connection_pool.channel_connection_ids(root_id) { let channels = channels .iter() .filter_map(|channel| { - if member.role.can_see_channel(channel.visibility) { + if role.can_see_channel(channel.visibility) { Some(channel.to_proto()) } else { None @@ -2732,9 +2737,7 @@ async fn move_channel( ..Default::default() }; - for connection_id in connection_pool.user_connection_ids(member.user_id) { - session.peer.send(connection_id, update.clone())?; - } + session.peer.send(connection_id, update.clone())?; } response.send(Ack {})?; @@ -2771,10 +2774,10 @@ async fn respond_to_channel_invite( .respond_to_channel_invite(channel_id, session.user_id, request.accept) .await?; - let connection_pool = session.connection_pool().await; + let mut connection_pool = session.connection_pool().await; if let Some(membership_update) = membership_update { notify_membership_updated( - &connection_pool, + &mut connection_pool, membership_update, session.user_id, &session.peer, @@ -2866,14 +2869,17 @@ async fn join_channel_internal( response.send(proto::JoinRoomResponse { room: Some(joined_room.room.clone()), - channel_id: joined_room.channel_id.map(|id| id.to_proto()), + channel_id: joined_room + .channel + .as_ref() + .map(|channel| channel.id.to_proto()), live_kit_connection_info, })?; - let connection_pool = session.connection_pool().await; + let mut connection_pool = session.connection_pool().await; if let Some(membership_updated) = membership_updated { notify_membership_updated( - &connection_pool, + &mut connection_pool, membership_updated, session.user_id, &session.peer, @@ -2886,9 +2892,10 @@ async fn join_channel_internal( }; channel_updated( - channel_id, + &joined_room + .channel + .ok_or_else(|| anyhow!("channel not returned"))?, &joined_room.room, - &joined_room.channel_members, &session.peer, &*session.connection_pool().await, ); @@ -3403,11 +3410,18 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage { } fn notify_membership_updated( - connection_pool: &ConnectionPool, + connection_pool: &mut ConnectionPool, result: MembershipUpdated, user_id: UserId, peer: &Peer, ) { + for membership in &result.new_channels.channel_memberships { + connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role) + } + for channel_id in &result.removed_channels { + connection_pool.unsubscribe_from_channel(&user_id, channel_id) + } + let user_channels_update = proto::UpdateUserChannels { channel_memberships: result .new_channels @@ -3420,6 +3434,7 @@ fn notify_membership_updated( .collect(), ..Default::default() }; + let mut update = build_channels_update(result.new_channels, vec![]); update.delete_channels = result .removed_channels @@ -3533,9 +3548,8 @@ fn room_updated(room: &proto::Room, peer: &Peer) { } fn channel_updated( - channel_id: ChannelId, + channel: &db::channel::Model, room: &proto::Room, - channel_members: &[UserId], peer: &Peer, pool: &ConnectionPool, ) { @@ -3547,15 +3561,16 @@ fn channel_updated( broadcast( None, - channel_members - .iter() - .flat_map(|user_id| pool.user_connection_ids(*user_id)), + pool.channel_connection_ids(channel.root_id()) + .filter_map(|(channel_id, role)| { + role.can_see_channel(channel.visibility).then(|| channel_id) + }), |peer_id| { peer.send( peer_id, proto::UpdateChannels { channel_participants: vec![proto::ChannelParticipants { - channel_id: channel_id.to_proto(), + channel_id: channel.id.to_proto(), participant_user_ids: participants.clone(), }], ..Default::default() @@ -3608,8 +3623,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { let live_kit_room; let delete_live_kit_room; let room; - let channel_members; - let channel_id; + let channel; if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? { contacts_to_update.insert(session.user_id); @@ -3623,19 +3637,17 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { live_kit_room = mem::take(&mut left_room.room.live_kit_room); delete_live_kit_room = left_room.deleted; room = mem::take(&mut left_room.room); - channel_members = mem::take(&mut left_room.channel_members); - channel_id = left_room.channel_id; + channel = mem::take(&mut left_room.channel); room_updated(&room, &session.peer); } else { return Ok(()); } - if let Some(channel_id) = channel_id { + if let Some(channel) = channel { channel_updated( - channel_id, + &channel, &room, - &channel_members, &session.peer, &*session.connection_pool().await, ); diff --git a/crates/collab/src/rpc/connection_pool.rs b/crates/collab/src/rpc/connection_pool.rs index 2d282903737118..acbd62c00893da 100644 --- a/crates/collab/src/rpc/connection_pool.rs +++ b/crates/collab/src/rpc/connection_pool.rs @@ -1,6 +1,6 @@ -use crate::db::UserId; +use crate::db::{ChannelId, ChannelRole, UserId}; use anyhow::{anyhow, Result}; -use collections::{BTreeMap, HashSet}; +use collections::{BTreeMap, HashMap, HashSet}; use rpc::ConnectionId; use serde::Serialize; use tracing::instrument; @@ -10,6 +10,7 @@ use util::SemanticVersion; pub struct ConnectionPool { connections: BTreeMap, connected_users: BTreeMap, + channels: ChannelPool, } #[derive(Default, Serialize)] @@ -47,6 +48,7 @@ impl ConnectionPool { pub fn reset(&mut self) { self.connections.clear(); self.connected_users.clear(); + self.channels.clear(); } #[instrument(skip(self))] @@ -81,6 +83,7 @@ impl ConnectionPool { connected_user.connection_ids.remove(&connection_id); if connected_user.connection_ids.is_empty() { self.connected_users.remove(&user_id); + self.channels.remove_user(&user_id); } self.connections.remove(&connection_id).unwrap(); Ok(()) @@ -110,6 +113,38 @@ impl ConnectionPool { .copied() } + pub fn channel_user_ids( + &self, + channel_id: ChannelId, + ) -> impl Iterator + '_ { + self.channels.users_to_notify(channel_id) + } + + pub fn channel_connection_ids( + &self, + channel_id: ChannelId, + ) -> impl Iterator + '_ { + self.channels + .users_to_notify(channel_id) + .flat_map(|(user_id, role)| { + self.user_connection_ids(user_id) + .map(move |connection_id| (connection_id, role)) + }) + } + + pub fn subscribe_to_channel( + &mut self, + user_id: UserId, + channel_id: ChannelId, + role: ChannelRole, + ) { + self.channels.subscribe(user_id, channel_id, role); + } + + pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) { + self.channels.unsubscribe(user_id, channel_id); + } + pub fn is_user_online(&self, user_id: UserId) -> bool { !self .connected_users @@ -140,3 +175,70 @@ impl ConnectionPool { } } } + +#[derive(Default, Serialize)] +pub struct ChannelPool { + by_user: HashMap>, + by_channel: HashMap>, +} + +impl ChannelPool { + pub fn clear(&mut self) { + self.by_user.clear(); + self.by_channel.clear(); + } + + pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) { + self.by_user + .entry(user_id) + .or_default() + .insert(channel_id, role); + self.by_channel + .entry(channel_id) + .or_default() + .insert(user_id); + } + + pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) { + if let Some(channels) = self.by_user.get_mut(user_id) { + channels.remove(channel_id); + if channels.is_empty() { + self.by_user.remove(user_id); + } + } + if let Some(users) = self.by_channel.get_mut(channel_id) { + users.remove(user_id); + if users.is_empty() { + self.by_channel.remove(channel_id); + } + } + } + + pub fn remove_user(&mut self, user_id: &UserId) { + if let Some(channels) = self.by_user.remove(&user_id) { + for channel_id in channels.keys() { + self.unsubscribe(user_id, &channel_id) + } + } + } + + pub fn users_to_notify( + &self, + channel_id: ChannelId, + ) -> impl '_ + Iterator { + self.by_channel + .get(&channel_id) + .into_iter() + .flat_map(move |users| { + users.iter().flat_map(move |user_id| { + Some(( + *user_id, + self.by_user + .get(user_id) + .and_then(|channels| channels.get(&channel_id)) + .copied()?, + )) + }) + }) + } +}