Skip to content

Commit

Permalink
Add streaming responses to the RPC system
Browse files Browse the repository at this point in the history
Co-authored-by: Antonio <antonio@zed.dev>

Resurrected this from some assistant work I did in Spring of 2023.
  • Loading branch information
nathansobo committed Feb 5, 2024
1 parent ac74a72 commit 6208743
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 10 deletions.
35 changes: 31 additions & 4 deletions crates/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use async_tungstenite::tungstenite::{
http::{Request, StatusCode},
};
use futures::{
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt,
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
TryFutureExt as _, TryStreamExt,
};
use gpui::{
Expand All @@ -35,7 +35,10 @@ use std::{
future::Future,
marker::PhantomData,
path::PathBuf,
sync::{atomic::AtomicU64, Arc, Weak},
sync::{
atomic::{AtomicU64, Ordering},
Arc, Weak,
},
time::{Duration, Instant},
};
use telemetry::Telemetry;
Expand Down Expand Up @@ -439,15 +442,15 @@ impl Client {
}

pub fn id(&self) -> u64 {
self.id.load(std::sync::atomic::Ordering::SeqCst)
self.id.load(Ordering::SeqCst)
}

pub fn http_client(&self) -> Arc<ZedHttpClient> {
self.http.clone()
}

pub fn set_id(&self, id: u64) -> &Self {
self.id.store(id, std::sync::atomic::Ordering::SeqCst);
self.id.store(id, Ordering::SeqCst);
self
}

Expand Down Expand Up @@ -1256,6 +1259,30 @@ impl Client {
.map_ok(|envelope| envelope.payload)
}

pub fn request_stream<T: RequestMessage>(
&self,
request: T,
) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
let client_id = self.id.load(Ordering::SeqCst);
log::debug!(
"rpc request start. client_id:{}. name:{}",
client_id,
T::NAME
);
let response = self
.connection_id()
.map(|conn_id| self.peer.request_stream(conn_id, request));
async move {
let response = response?.await;
log::debug!(
"rpc request finish. client_id:{}. name:{}",
client_id,
T::NAME
);
response
}
}

pub fn request_envelope<T: RequestMessage>(
&self,
request: T,
Expand Down
32 changes: 32 additions & 0 deletions crates/collab/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod rpc;
#[cfg(test)]
mod tests;

use ::rpc::{proto, ErrorExt};
use axum::{http::StatusCode, response::IntoResponse};
use db::Database;
use executor::Executor;
Expand All @@ -22,6 +23,37 @@ pub enum Error {
Internal(anyhow::Error),
}

impl ErrorExt for Error {
fn error_code(&self) -> proto::ErrorCode {
match self {
Error::Internal(anyhow_error) => anyhow_error.error_code(),
_ => proto::ErrorCode::Internal,
}
}

fn error_tag(&self, k: &str) -> Option<&str> {
match self {
Error::Internal(anyhow_error) => anyhow_error.error_tag(k),
_ => None,
}
}

fn to_proto(&self) -> proto::Error {
match self {
Error::Internal(anyhow_error) => anyhow_error.to_proto(),
_ => proto::Error {
message: self.to_string(),
code: self.error_code() as i32,
tags: Vec::new(),
},
}
}

fn cloned(&self) -> anyhow::Error {
todo!()
}
}

impl From<anyhow::Error> for Error {
fn from(error: anyhow::Error) -> Self {
Self::Internal(error)
Expand Down
52 changes: 52 additions & 0 deletions crates/collab/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ impl<R: RequestMessage> Response<R> {
}
}

struct StreamingResponse<R: RequestMessage> {
peer: Arc<Peer>,
receipt: Receipt<R>,
}

impl<R: RequestMessage> StreamingResponse<R> {
fn send(&self, payload: R::Response) -> Result<()> {
self.peer.respond(self.receipt, payload)?;
Ok(())
}
}

impl<R: RequestMessage> Drop for StreamingResponse<R> {
fn drop(&mut self) {
self.peer.end_stream(self.receipt).trace_err();
}
}

#[derive(Clone)]
struct Session {
zed_environment: Arc<str>,
Expand Down Expand Up @@ -554,6 +572,40 @@ impl Server {
})
}

fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
Fut: Send + Future<Output = Result<()>>,
M: RequestMessage,
{
let handler = Arc::new(handler);
self.add_handler(move |envelope, session| {
let receipt = envelope.receipt();
let handler = handler.clone();
async move {
let peer = session.peer.clone();
let responded = Arc::new(AtomicBool::default());
let response = StreamingResponse {
peer: peer.clone(),
receipt,
};
match (handler)(envelope.payload, response, session).await {
Ok(()) => {
if responded.load(std::sync::atomic::Ordering::SeqCst) {
Ok(())
} else {
Err(anyhow!("handler did not send a response"))?
}
}
Err(error) => {
peer.respond_with_error(receipt, error.to_proto())?;
Err(error)
}
}
}
})
}

pub fn handle_connection(
self: &Arc<Self>,
connection: Connection,
Expand Down
7 changes: 5 additions & 2 deletions crates/rpc/proto/zed.proto
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
syntax = "proto3";
package zed.messages;

// Looking for a number? Search "// Current max"
// Looking for a number? Search "// current max"

message PeerId {
uint32 owner_id = 1;
Expand All @@ -18,6 +18,7 @@ message Envelope {
Error error = 6;
Ping ping = 7;
Test test = 8;
EndStream end_stream = 158; // current max

CreateRoom create_room = 9;
CreateRoomResponse create_room_response = 10;
Expand Down Expand Up @@ -183,7 +184,7 @@ message Envelope {
LspExtExpandMacroResponse lsp_ext_expand_macro_response = 155;
SetRoomParticipantRole set_room_participant_role = 156;

UpdateUserChannels update_user_channels = 157; // current max
UpdateUserChannels update_user_channels = 157;
}
}

Expand Down Expand Up @@ -219,6 +220,8 @@ enum ErrorCode {
UnsharedItem = 12;
}

message EndStream {}

message Test {
uint64 id = 1;
}
Expand Down
Loading

0 comments on commit 6208743

Please sign in to comment.