From 6208743c13d55fd1fdaeac832d212b1ff2530657 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sun, 4 Feb 2024 21:46:21 -0700 Subject: [PATCH] Add streaming responses to the RPC system Co-authored-by: Antonio Resurrected this from some assistant work I did in Spring of 2023. --- crates/client/src/client.rs | 35 ++++++++- crates/collab/src/lib.rs | 32 ++++++++ crates/collab/src/rpc.rs | 52 +++++++++++++ crates/rpc/proto/zed.proto | 7 +- crates/rpc/src/peer.rs | 141 +++++++++++++++++++++++++++++++++++- crates/rpc/src/proto.rs | 1 + 6 files changed, 258 insertions(+), 10 deletions(-) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index ff8adc96607b85..f5e1900190aac1 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -11,7 +11,7 @@ use async_tungstenite::tungstenite::{ http::{Request, StatusCode}, }; use futures::{ - channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, + channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, }; use gpui::{ @@ -35,7 +35,10 @@ use std::{ future::Future, marker::PhantomData, path::PathBuf, - sync::{atomic::AtomicU64, Arc, Weak}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Weak, + }, time::{Duration, Instant}, }; use telemetry::Telemetry; @@ -439,7 +442,7 @@ impl Client { } pub fn id(&self) -> u64 { - self.id.load(std::sync::atomic::Ordering::SeqCst) + self.id.load(Ordering::SeqCst) } pub fn http_client(&self) -> Arc { @@ -447,7 +450,7 @@ impl Client { } pub fn set_id(&self, id: u64) -> &Self { - self.id.store(id, std::sync::atomic::Ordering::SeqCst); + self.id.store(id, Ordering::SeqCst); self } @@ -1256,6 +1259,30 @@ impl Client { .map_ok(|envelope| envelope.payload) } + pub fn request_stream( + &self, + request: T, + ) -> impl Future>>> { + let client_id = self.id.load(Ordering::SeqCst); + log::debug!( + "rpc request start. client_id:{}. name:{}", + client_id, + T::NAME + ); + let response = self + .connection_id() + .map(|conn_id| self.peer.request_stream(conn_id, request)); + async move { + let response = response?.await; + log::debug!( + "rpc request finish. client_id:{}. name:{}", + client_id, + T::NAME + ); + response + } + } + pub fn request_envelope( &self, request: T, diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index aba9bd75d1f0aa..7993c750d3aaf4 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -8,6 +8,7 @@ pub mod rpc; #[cfg(test)] mod tests; +use ::rpc::{proto, ErrorExt}; use axum::{http::StatusCode, response::IntoResponse}; use db::Database; use executor::Executor; @@ -22,6 +23,37 @@ pub enum Error { Internal(anyhow::Error), } +impl ErrorExt for Error { + fn error_code(&self) -> proto::ErrorCode { + match self { + Error::Internal(anyhow_error) => anyhow_error.error_code(), + _ => proto::ErrorCode::Internal, + } + } + + fn error_tag(&self, k: &str) -> Option<&str> { + match self { + Error::Internal(anyhow_error) => anyhow_error.error_tag(k), + _ => None, + } + } + + fn to_proto(&self) -> proto::Error { + match self { + Error::Internal(anyhow_error) => anyhow_error.to_proto(), + _ => proto::Error { + message: self.to_string(), + code: self.error_code() as i32, + tags: Vec::new(), + }, + } + } + + fn cloned(&self) -> anyhow::Error { + todo!() + } +} + impl From for Error { fn from(error: anyhow::Error) -> Self { Self::Internal(error) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index c97c283b2fda64..560958f722639f 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -100,6 +100,24 @@ impl Response { } } +struct StreamingResponse { + peer: Arc, + receipt: Receipt, +} + +impl StreamingResponse { + fn send(&self, payload: R::Response) -> Result<()> { + self.peer.respond(self.receipt, payload)?; + Ok(()) + } +} + +impl Drop for StreamingResponse { + fn drop(&mut self) { + self.peer.end_stream(self.receipt).trace_err(); + } +} + #[derive(Clone)] struct Session { zed_environment: Arc, @@ -554,6 +572,40 @@ impl Server { }) } + fn add_streaming_request_handler(&mut self, handler: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(M, StreamingResponse, Session) -> Fut, + Fut: Send + Future>, + M: RequestMessage, + { + let handler = Arc::new(handler); + self.add_handler(move |envelope, session| { + let receipt = envelope.receipt(); + let handler = handler.clone(); + async move { + let peer = session.peer.clone(); + let responded = Arc::new(AtomicBool::default()); + let response = StreamingResponse { + peer: peer.clone(), + receipt, + }; + match (handler)(envelope.payload, response, session).await { + Ok(()) => { + if responded.load(std::sync::atomic::Ordering::SeqCst) { + Ok(()) + } else { + Err(anyhow!("handler did not send a response"))? + } + } + Err(error) => { + peer.respond_with_error(receipt, error.to_proto())?; + Err(error) + } + } + } + }) + } + pub fn handle_connection( self: &Arc, connection: Connection, diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 3a513902e5a5cd..074c9a90fade00 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -1,7 +1,7 @@ syntax = "proto3"; package zed.messages; -// Looking for a number? Search "// Current max" +// Looking for a number? Search "// current max" message PeerId { uint32 owner_id = 1; @@ -18,6 +18,7 @@ message Envelope { Error error = 6; Ping ping = 7; Test test = 8; + EndStream end_stream = 158; // current max CreateRoom create_room = 9; CreateRoomResponse create_room_response = 10; @@ -183,7 +184,7 @@ message Envelope { LspExtExpandMacroResponse lsp_ext_expand_macro_response = 155; SetRoomParticipantRole set_room_participant_role = 156; - UpdateUserChannels update_user_channels = 157; // current max + UpdateUserChannels update_user_channels = 157; } } @@ -219,6 +220,8 @@ enum ErrorCode { UnsharedItem = 12; } +message EndStream {} + message Test { uint64 id = 1; } diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 9d789bd3d01aef..8b61ff156ff6fa 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -9,14 +9,15 @@ use collections::HashMap; use futures::{ channel::{mpsc, oneshot}, stream::BoxStream, - FutureExt, SinkExt, StreamExt, TryFutureExt, + FutureExt, SinkExt, Stream, StreamExt, TryFutureExt, }; use parking_lot::{Mutex, RwLock}; use serde::{ser::SerializeStruct, Serialize}; -use std::{fmt, sync::atomic::Ordering::SeqCst}; use std::{ + fmt, future, future::Future, marker::PhantomData, + sync::atomic::Ordering::SeqCst, sync::{ atomic::{self, AtomicU32}, Arc, @@ -113,6 +114,15 @@ pub struct ConnectionState { #[serde(skip)] response_channels: Arc)>>>>>, + #[allow(clippy::type_complexity)] + #[serde(skip)] + stream_response_channels: Arc< + Mutex< + Option< + HashMap, oneshot::Sender<()>)>>, + >, + >, + >, } const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1); @@ -166,17 +176,28 @@ impl Peer { outgoing_tx, next_message_id: Default::default(), response_channels: Arc::new(Mutex::new(Some(Default::default()))), + stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))), }; let mut writer = MessageStream::new(connection.tx); let mut reader = MessageStream::new(connection.rx); let this = self.clone(); let response_channels = connection_state.response_channels.clone(); + let stream_response_channels = connection_state.stream_response_channels.clone(); + let handle_io = async move { tracing::trace!(%connection_id, "handle io future: start"); let _end_connection = util::defer(|| { response_channels.lock().take(); + if let Some(channels) = stream_response_channels.lock().take() { + for channel in channels.values() { + let _ = channel.unbounded_send(( + Err(anyhow!("connection closed")), + oneshot::channel().0, + )); + } + } this.connections.write().remove(&connection_id); tracing::trace!(%connection_id, "handle io future: end"); }); @@ -268,12 +289,14 @@ impl Peer { }; let response_channels = connection_state.response_channels.clone(); + let stream_response_channels = connection_state.stream_response_channels.clone(); self.connections .write() .insert(connection_id, connection_state); let incoming_rx = incoming_rx.filter_map(move |incoming| { let response_channels = response_channels.clone(); + let stream_response_channels = stream_response_channels.clone(); async move { let message_id = incoming.id; tracing::trace!(?incoming, "incoming message future: start"); @@ -288,8 +311,15 @@ impl Peer { responding_to, "incoming response: received" ); - let channel = response_channels.lock().as_mut()?.remove(&responding_to); - if let Some(tx) = channel { + let response_channel = + response_channels.lock().as_mut()?.remove(&responding_to); + let stream_response_channel = stream_response_channels + .lock() + .as_ref()? + .get(&responding_to) + .cloned(); + + if let Some(tx) = response_channel { let requester_resumed = oneshot::channel(); if let Err(error) = tx.send((incoming, requester_resumed.0)) { tracing::trace!( @@ -314,6 +344,31 @@ impl Peer { responding_to, "incoming response: requester resumed" ); + } else if let Some(tx) = stream_response_channel { + let requester_resumed = oneshot::channel(); + if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) { + tracing::debug!( + %connection_id, + message_id, + responding_to = responding_to, + ?error, + "incoming stream response: request future dropped", + ); + } + + tracing::debug!( + %connection_id, + message_id, + responding_to, + "incoming stream response: waiting to resume requester" + ); + let _ = requester_resumed.1.await; + tracing::debug!( + %connection_id, + message_id, + responding_to, + "incoming stream response: requester resumed" + ); } else { tracing::warn!( %connection_id, @@ -438,6 +493,66 @@ impl Peer { } } + pub fn request_stream( + &self, + receiver_id: ConnectionId, + request: T, + ) -> impl Future>>> { + let (tx, rx) = mpsc::unbounded(); + let send = self.connection_state(receiver_id).and_then(|connection| { + let message_id = connection.next_message_id.fetch_add(1, SeqCst); + let stream_response_channels = connection.stream_response_channels.clone(); + stream_response_channels + .lock() + .as_mut() + .ok_or_else(|| anyhow!("connection was closed"))? + .insert(message_id, tx); + connection + .outgoing_tx + .unbounded_send(proto::Message::Envelope( + request.into_envelope(message_id, None, None), + )) + .map_err(|_| anyhow!("connection was closed"))?; + Ok((message_id, stream_response_channels)) + }); + + async move { + let (message_id, stream_response_channels) = send?; + let stream_response_channels = Arc::downgrade(&stream_response_channels); + + Ok(rx.filter_map(move |(response, _barrier)| { + let stream_response_channels = stream_response_channels.clone(); + future::ready(match response { + Ok(response) => { + if let Some(proto::envelope::Payload::Error(error)) = &response.payload { + Some(Err(anyhow!( + "RPC request {} failed - {}", + T::NAME, + error.message + ))) + } else if let Some(proto::envelope::Payload::EndStream(_)) = + &response.payload + { + // Remove the transmitting end of the response channel to end the stream. + if let Some(channels) = stream_response_channels.upgrade() { + if let Some(channels) = channels.lock().as_mut() { + channels.remove(&message_id); + } + } + None + } else { + Some( + T::Response::from_envelope(response) + .ok_or_else(|| anyhow!("received response of the wrong type")), + ) + } + } + Err(error) => Some(Err(error)), + }) + })) + } + } + pub fn send(&self, receiver_id: ConnectionId, message: T) -> Result<()> { let connection = self.connection_state(receiver_id)?; let message_id = connection @@ -490,6 +605,24 @@ impl Peer { Ok(()) } + pub fn end_stream(&self, receipt: Receipt) -> Result<()> { + let connection = self.connection_state(receipt.sender_id)?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + + let message = proto::EndStream {}; + + connection + .outgoing_tx + .unbounded_send(proto::Message::Envelope(message.into_envelope( + message_id, + Some(receipt.message_id), + None, + )))?; + Ok(()) + } + pub fn respond_with_error( &self, receipt: Receipt, diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 9b885d1840f596..4d67b5a8aa5f34 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -159,6 +159,7 @@ messages!( (DeleteChannel, Foreground), (DeleteNotification, Foreground), (DeleteProjectEntry, Foreground), + (EndStream, Foreground), (Error, Foreground), (ExpandProjectEntry, Foreground), (ExpandProjectEntryResponse, Foreground),