Skip to content

Commit

Permalink
Add grpc & flmping
Browse files Browse the repository at this point in the history
Signed-off-by: Klaus Ma <klausm@nvidia.com>
  • Loading branch information
k82cn committed Feb 7, 2025
1 parent 97bb0c1 commit 65f9eaa
Show file tree
Hide file tree
Showing 43 changed files with 642 additions and 177 deletions.
166 changes: 118 additions & 48 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ members = [
"session_manager",
"executor_manager",
"rpc",
"sdk/service",
"sdk/client",
"sdk/rust/service",
"sdk/rust/client",
]

[workspace.dependencies]
Expand Down
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ ci-image:
sudo docker build -t xflops/flame-console -f docker/Dockerfile.console .

update_protos:
cp rpc/protos/frontend.proto sdk/client/protos
cp rpc/protos/types.proto sdk/client/protos
cp rpc/protos/shim.proto sdk/service/protos
cp rpc/protos/types.proto sdk/service/protos
cp rpc/protos/frontend.proto sdk/rust/client/protos
cp rpc/protos/types.proto sdk/rust/client/protos
cp rpc/protos/shim.proto sdk/rust/service/protos
cp rpc/protos/types.proto sdk/rust/service/protos
46 changes: 40 additions & 6 deletions common/src/apis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub enum Shim {
Stdio = 1,
Wasm = 2,
Shell = 3,
Grpc = 4,
}

#[derive(Clone, Debug, Default)]
Expand Down Expand Up @@ -153,15 +154,15 @@ pub struct Executor {

#[derive(Clone, Debug)]
pub struct TaskContext {
pub id: String,
pub ssn_id: String,
pub task_id: String,
pub session_id: String,
pub input: Option<TaskInput>,
pub output: Option<TaskOutput>,
}

#[derive(Clone, Debug)]
pub struct SessionContext {
pub ssn_id: String,
pub session_id: String,
pub application: ApplicationContext,
pub slots: i32,
pub common_data: Option<CommonData>,
Expand Down Expand Up @@ -243,8 +244,8 @@ impl TryFrom<rpc::Task> for TaskContext {
.ok_or(FlameError::InvalidConfig("spec".to_string()))?;

Ok(TaskContext {
id: metadata.id,
ssn_id: spec.session_id.to_string(),
task_id: metadata.id.clone(),
session_id: spec.session_id.to_string(),
input: spec.input.map(TaskInput::from),
output: spec.output.map(TaskOutput::from),
})
Expand Down Expand Up @@ -273,6 +274,37 @@ impl TryFrom<rpc::Application> for ApplicationContext {
}
}

/// Converts an executor-side [`TaskContext`] into the gRPC `rpc::TaskContext` message.
///
/// Only `task_id`, `session_id` and `input` are carried over; the `output`
/// field of the source context is not part of the message built here.
impl From<TaskContext> for rpc::TaskContext {
    fn from(ctx: TaskContext) -> Self {
        // `ctx` is owned, so its fields are moved into the message instead of
        // being cloned and then dropped.
        Self {
            task_id: ctx.task_id,
            session_id: ctx.session_id,
            input: ctx.input.map(Into::into),
        }
    }
}

/// Converts an executor-side [`SessionContext`] into the gRPC `rpc::SessionContext` message.
///
/// The application context is always present in the result (wrapped in `Some`);
/// `slots` is not forwarded by this conversion.
impl From<SessionContext> for rpc::SessionContext {
    fn from(ctx: SessionContext) -> Self {
        // `ctx` is owned, so `session_id` is moved rather than cloned.
        Self {
            session_id: ctx.session_id,
            application: Some(ctx.application.into()),
            common_data: ctx.common_data.map(Into::into),
        }
    }
}

/// Converts an executor-side [`ApplicationContext`] into the gRPC
/// `rpc::ApplicationContext` message, field by field.
impl From<ApplicationContext> for rpc::ApplicationContext {
    fn from(ctx: ApplicationContext) -> Self {
        // `ctx` is owned, so `name`, `url` and `command` are moved into the
        // message instead of being cloned and then dropped.
        Self {
            name: ctx.name,
            url: ctx.url,
            shim: ctx.shim.into(),
            command: ctx.command,
        }
    }
}

impl TryFrom<rpc::BindExecutorResponse> for SessionContext {
type Error = FlameError;

Expand All @@ -295,7 +327,7 @@ impl TryFrom<rpc::BindExecutorResponse> for SessionContext {
let application = ApplicationContext::try_from(app)?;

Ok(SessionContext {
ssn_id: metadata.id,
session_id: metadata.id,
application,
slots: spec.slots,
common_data: spec.common_data.map(CommonData::from),
Expand Down Expand Up @@ -559,6 +591,7 @@ impl From<rpc::Shim> for Shim {
rpc::Shim::Stdio => Self::Stdio,
rpc::Shim::Wasm => Self::Wasm,
rpc::Shim::Shell => Self::Shell,
rpc::Shim::Grpc => Self::Grpc,
}
}
}
Expand All @@ -570,6 +603,7 @@ impl From<Shim> for rpc::Shim {
Shim::Stdio => Self::Stdio,
Shim::Wasm => Self::Wasm,
Shim::Shell => Self::Shell,
Shim::Grpc => Self::Grpc,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions executor_manager/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ log = { workspace = true }
async-trait = { workspace = true }
clap = { workspace = true }
prost = { workspace = true }
tower = "0.5"
hyper-util = "0.1"

bytes = "1"
chrono = "0.4"
Expand Down
88 changes: 80 additions & 8 deletions executor_manager/src/shims/grpc_shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,114 @@ limitations under the License.
*/

use std::sync::Arc;
use std::{thread, time};

use async_trait::async_trait;
use hyper_util::rt::TokioIo;
use tokio::net::UnixStream;
use tokio::sync::Mutex;
use tonic::transport::Channel;
use tonic::transport::{Endpoint, Uri};
use tonic::Request;
use tower::service_fn;

use ::rpc::flame as rpc;
use rpc::grpc_shim_client::GrpcShimClient;
use rpc::EmptyRequest;

use crate::shims::{Shim, ShimPtr};
use common::apis::{ApplicationContext, SessionContext, TaskContext, TaskOutput};
use common::FlameError;

// Shim that drives an application service over gRPC.
// NOTE(review): this span comes from a rendered diff, so it may mix removed and
// added lines; in particular confirm whether `#[derive(Clone)]` still applies
// now that the struct holds a `tokio::process::Child`.
#[derive(Clone)]
pub struct GrpcShim {
// Session context; `new_ptr` initialises it to `None`.
session_context: Option<SessionContext>,
// gRPC client used for the readiness probe and the session/task callbacks.
client: GrpcShimClient<Channel>,
// Handle of the service process spawned by `new_ptr`.
child: tokio::process::Child,
}

const FLAME_SOCKET_PATH: &str = "FLAME_SOCKET_PATH";

// Constructor for `GrpcShim`.
// NOTE(review): this span comes from a rendered diff. The first three body lines
// below look like the removed synchronous `new_ptr`, followed by the added async
// version; the same applies to the two closing forms near the end.
impl GrpcShim {
pub fn new_ptr(_: &ApplicationContext) -> ShimPtr {
// TODO: launch service based on application context.
Arc::new(Mutex::new(Self {
// Async constructor: spawns the application command as a child process, connects
// to it over a Unix domain socket, polls its `readiness` RPC, and returns the
// shim wrapped in `Arc<Mutex<_>>`.
//
// Errors: `FlameError::InvalidConfig` if the process cannot be spawned or the
// readiness probe never succeeds; `FlameError::Network` if the endpoint cannot be
// built or the channel cannot be connected.
// Panics: if `app_ctx.command` is `None` (see the `unwrap()` below).
pub async fn new_ptr(app_ctx: &ApplicationContext) -> Result<ShimPtr, FlameError> {
// Per-shim socket path, made unique with a fresh UUID.
let socket_path = format!("/tmp/flame-shim-{}.sock", uuid::Uuid::new_v4().simple());
// Publishes the path in this process's environment; the connector closure below
// reads it back from there.
// NOTE(review): the environment is process-global, so two shims created
// concurrently could presumably observe each other's path — confirm.
std::env::set_var(FLAME_SOCKET_PATH, socket_path.clone());

// Spawn child process
let mut cmd = tokio::process::Command::new(&app_ctx.command.clone().unwrap());
// The child receives the socket path via its environment and is killed when the
// `Child` handle is dropped.
cmd.env(FLAME_SOCKET_PATH, &socket_path).kill_on_drop(true);

// NOTE(review): `.env(FLAME_SOCKET_PATH, ..)` is applied a second time here; it
// repeats the call just above.
let child = cmd
.env(FLAME_SOCKET_PATH, &socket_path)
.spawn()
.map_err(|e| FlameError::InvalidConfig(e.to_string()))?;

// The URI is only a placeholder: the custom connector ignores it (`|_: Uri|`)
// and connects to the Unix socket instead.
// NOTE(review): presumably this fails right away if the child has not created
// the socket yet, i.e. before the readiness loop below can retry — confirm.
let channel = Endpoint::try_from("http://[::]:50051")
.map_err(|e| FlameError::Network(e.to_string()))?
.connect_with_connector(service_fn(|_: Uri| async {
// Connect to a Uds socket
let path = std::env::var(FLAME_SOCKET_PATH).ok().unwrap();
Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path).await?))
}))
.await
.map_err(|e| FlameError::Network(e.to_string()))?;

let mut client = GrpcShimClient::new(channel);

// Poll the service's `readiness` RPC until it answers successfully.
// `1..10` yields nine attempts; the loop variable `i` is not used.
let mut connected = false;
for i in 1..10 {
let resp = client
.readiness(Request::new(EmptyRequest::default()))
.await;
if resp.is_ok() {
connected = true;
break;
}
// sleep 1s
// NOTE(review): despite its name, `ten_millis` holds one second, and
// `thread::sleep` blocks the OS thread inside an `async fn`;
// `tokio::time::sleep(..).await` is the non-blocking form.
let ten_millis = time::Duration::from_secs(1);
thread::sleep(ten_millis);
}

// Give up if no probe succeeded.
if !connected {
return Err(FlameError::InvalidConfig(
"failed to connect to service".to_string(),
));
}

Ok(Arc::new(Mutex::new(Self {
session_context: None,
}))
client,
child,
})))
}
}

// `Shim` implementation that forwards each lifecycle callback to the service
// through the gRPC client.
// NOTE(review): this span comes from a rendered diff; each `todo!()` line looks
// like the removed placeholder body, followed by the added implementation.
#[async_trait]
impl Shim for GrpcShim {
// Sends a clone of the session context, converted to its rpc form, to the
// service's `on_session_enter` RPC.
async fn on_session_enter(&mut self, ctx: &SessionContext) -> Result<(), FlameError> {
todo!()
let req = Request::new(rpc::SessionContext::from(ctx.clone()));
self.client.on_session_enter(req).await?;
Ok(())
}

// Sends the task context to the service's `on_task_invoke` RPC and returns the
// `data` field of the response, converted to `TaskOutput`, if present.
async fn on_task_invoke(
&mut self,
ctx: &TaskContext,
) -> Result<Option<TaskOutput>, FlameError> {
todo!()
let req = Request::new(rpc::TaskContext::from(ctx.clone()));
let resp = self.client.on_task_invoke(req).await?;
let output = resp.into_inner();

Ok(output.data.map(|d| d.into()))
}

// Notifies the service via `on_session_leave`, then asks the child process to stop.
async fn on_session_leave(&mut self) -> Result<(), FlameError> {
todo!()
let _ = self
.client
.on_session_leave(Request::new(EmptyRequest::default()))
.await?;

// NOTE(review): tokio's `Child::kill` is an async fn, so without `.await`
// this statement presumably only builds a future and never signals the
// process — confirm; `start_kill()` is the synchronous variant.
self.child.kill();

Ok(())
}
}
8 changes: 4 additions & 4 deletions executor_manager/src/shims/log_shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Shim for LogShim {
async fn on_session_enter(&mut self, ctx: &SessionContext) -> Result<(), FlameError> {
log::info!(
"on_session_enter: Session: <{}>, Application: <{}>, Slots: <{}>",
ctx.ssn_id,
ctx.session_id,
ctx.application.name,
ctx.slots
);
Expand All @@ -53,8 +53,8 @@ impl Shim for LogShim {
) -> Result<Option<TaskOutput>, FlameError> {
log::info!(
"on_task_invoke: Task: <{}>, Session: <{}>",
ctx.id,
ctx.ssn_id
ctx.task_id,
ctx.session_id
);
Ok(None)
}
Expand All @@ -67,7 +67,7 @@ impl Shim for LogShim {
Some(ctx) => {
log::info!(
"on_session_leave: Session: <{}>, Application: <{}>, Slots: <{}>",
ctx.ssn_id,
ctx.session_id,
ctx.application.name,
ctx.slots
);
Expand Down
8 changes: 5 additions & 3 deletions executor_manager/src/shims/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,22 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

mod grpc_shim;
mod log_shim;
mod shell_shim;
mod stdio_shim;
mod wasm_shim;
mod shell_shim;
mod grpc_shim;

use std::sync::Arc;

use async_trait::async_trait;
use grpc_shim::GrpcShim;
use tokio::sync::Mutex;

use self::log_shim::LogShim;
use self::shell_shim::ShellShim;
use self::stdio_shim::StdioShim;
use self::wasm_shim::WasmShim;
use self::shell_shim::ShellShim;

use common::apis::{ApplicationContext, SessionContext, Shim as ShimType, TaskContext, TaskOutput};

Expand All @@ -38,6 +39,7 @@ pub async fn from(app: &ApplicationContext) -> Result<ShimPtr, FlameError> {
ShimType::Stdio => Ok(StdioShim::new_ptr(app)),
ShimType::Wasm => Ok(WasmShim::new_ptr(app).await?),
ShimType::Shell => Ok(ShellShim::new_ptr(app)),
ShimType::Grpc => Ok(GrpcShim::new_ptr(app).await?),
_ => Ok(LogShim::new_ptr(app)),
}
}
Expand Down
13 changes: 8 additions & 5 deletions executor_manager/src/shims/shell_shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,12 @@ impl Shim for ShellShim {
&mut self,
ctx: &TaskContext,
) -> Result<Option<TaskOutput>, FlameError> {
let input = ctx.input.clone().ok_or(FlameError::Uninitialized(String::from(
"task input is empty",
)))?;
let input = ctx
.input
.clone()
.ok_or(FlameError::Uninitialized(String::from(
"task input is empty",
)))?;
let mut cmd = String::from_utf8(input.to_ascii_lowercase())
.map_err(|e| FlameError::Uninitialized(format!("task input is invalid: {}", e)))?;

Expand Down Expand Up @@ -86,8 +89,8 @@ impl Shim for ShellShim {
.stdout(Stdio::piped())
// TODO: add working dir
// .current_dir(&self.application.working_directory)
.env(FLAME_TASK_ID, &ctx.id)
.env(FLAME_SESSION_ID, &ctx.ssn_id)
.env(FLAME_TASK_ID, &ctx.task_id)
.env(FLAME_SESSION_ID, &ctx.session_id)
.spawn()
.map_err(|_| FlameError::Internal("failed to start subprocess".to_string()))?;

Expand Down
4 changes: 2 additions & 2 deletions executor_manager/src/shims/stdio_shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ impl Shim for StdioShim {
.stdout(Stdio::piped())
// TODO: add working dir
// .current_dir(&self.application.working_directory)
.env(FLAME_TASK_ID, &ctx.id)
.env(FLAME_SESSION_ID, &ctx.ssn_id)
.env(FLAME_TASK_ID, &ctx.task_id)
.env(FLAME_SESSION_ID, &ctx.session_id)
.spawn()
.map_err(|_| FlameError::Internal("failed to start subprocess".to_string()))?;

Expand Down
8 changes: 4 additions & 4 deletions executor_manager/src/shims/wasm_shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl Shim for WasmShim {
ctx: &apis::SessionContext,
) -> Result<(), common::FlameError> {
let ssn_ctx = service::SessionContext {
session_id: ctx.ssn_id.clone(),
session_id: ctx.session_id.clone(),
common_data: ctx.common_data.clone().map(apis::CommonData::into),
};

Expand All @@ -101,8 +101,8 @@ impl Shim for WasmShim {
ctx: &apis::TaskContext,
) -> Result<Option<apis::TaskOutput>, common::FlameError> {
let task_ctx = service::TaskContext {
session_id: ctx.ssn_id.clone(),
task_id: ctx.id.clone(),
session_id: ctx.session_id.clone(),
task_id: ctx.task_id.clone(),
};

let output = self
Expand All @@ -122,7 +122,7 @@ impl Shim for WasmShim {

async fn on_session_leave(&mut self) -> Result<(), common::FlameError> {
let ssn_ctx = service::SessionContext {
session_id: self.session_context.clone().unwrap().ssn_id.clone(),
session_id: self.session_context.clone().unwrap().session_id.clone(),
common_data: None,
};

Expand Down
2 changes: 1 addition & 1 deletion executor_manager/src/states/bound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl State for BoundState {

let (ssn_id, task_id) = {
let task = &self.executor.task.clone().unwrap();
(task.ssn_id.clone(), task.id.clone())
(task.session_id.clone(), task.task_id.clone())
};
log::debug!("Complete task <{}/{}>", ssn_id, task_id)
}
Expand Down
Loading

0 comments on commit 65f9eaa

Please sign in to comment.