diff --git a/Cargo.lock b/Cargo.lock index 03748cc86..dfdb35805 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -936,11 +936,13 @@ dependencies = [ "tempfile", "thiserror 2.0.11", "tokio", + "tokio-async-drop", "tokio-stream", "toml", "tracing", "tracing-opentelemetry", "tracing-subscriber", + "trait-variant", "ttrpc-codegen", "wasmparser 0.226.0", "wat", diff --git a/crates/containerd-shim-wasm/Cargo.toml b/crates/containerd-shim-wasm/Cargo.toml index 61e191639..f09e21601 100644 --- a/crates/containerd-shim-wasm/Cargo.toml +++ b/crates/containerd-shim-wasm/Cargo.toml @@ -36,6 +36,8 @@ sha256 = { workspace = true } serde_bytes = "0.11" prost = "0.13" toml = "0.8" +trait-variant = "0.1" +tokio-async-drop = "0.1" # tracing # note: it's important to keep the version of tracing in sync with tracing-subscriber diff --git a/crates/containerd-shim-wasm/src/sandbox/instance.rs b/crates/containerd-shim-wasm/src/sandbox/instance.rs index ec8c0a681..2d34ecf38 100644 --- a/crates/containerd-shim-wasm/src/sandbox/instance.rs +++ b/crates/containerd-shim-wasm/src/sandbox/instance.rs @@ -32,28 +32,29 @@ pub struct InstanceConfig { /// Instance is a trait that gets implemented by consumers of this library. /// This trait requires that any type implementing it is `'static`, similar to `std::any::Any`. /// This means that the type cannot contain a non-`'static` reference. +#[trait_variant::make(Send)] pub trait Instance: 'static { /// The WASI engine type type Engine: Send + Sync + Clone; /// Create a new instance - fn new(id: String, cfg: &InstanceConfig) -> Result + async fn new(id: String, cfg: &InstanceConfig) -> Result where Self: Sized; /// Start the instance /// The returned value should be a unique ID (such as a PID) for the instance. /// Nothing internally should be using this ID, but it is returned to containerd where a user may want to use it. - fn start(&self) -> Result; + async fn start(&self) -> Result; /// Send a signal to the instance - fn kill(&self, signal: u32) -> Result<(), Error>; + async fn kill(&self, signal: u32) -> Result<(), Error>; /// Delete any reference to the instance /// This is called after the instance has exited. - fn delete(&self) -> Result<(), Error>; + async fn delete(&self) -> Result<(), Error>; /// Waits for the instance to finish and returns its exit code /// This is an async call. - fn wait(&self) -> impl Future)> + Send; + async fn wait(&self) -> (u32, DateTime); } diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs b/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs index a85aeba7a..849f9480c 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs @@ -1,25 +1,26 @@ use std::env::current_dir; use std::fmt::Debug; -use std::sync::Arc; use chrono::Utc; use containerd_shim::error::Error as ShimError; use containerd_shim::publisher::RemotePublisher; use containerd_shim::util::write_address; -use containerd_shim::{self as shim, ExitSignal, api}; +use containerd_shim::{self as shim, api}; use oci_spec::runtime::Spec; use shim::Flags; +use crate::sandbox::async_utils::AmbientRuntime as _; use crate::sandbox::instance::Instance; use crate::sandbox::shim::events::{RemoteEventSender, ToTimestamp}; use crate::sandbox::shim::local::Local; +use crate::sandbox::sync::WaitableCell; /// Cli implements the containerd-shim cli interface using `Local` as the task service. pub struct Cli { engine: T::Engine, namespace: String, containerd_address: String, - exit: Arc, + exit: WaitableCell<()>, _id: String, } @@ -50,7 +51,7 @@ where engine: Default::default(), namespace: args.namespace.to_string(), containerd_address: args.address.clone(), - exit: Arc::default(), + exit: WaitableCell::new(), _id: args.id.to_string(), } } @@ -78,7 +79,7 @@ where #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] fn wait(&mut self) { - self.exit.wait(); + self.exit.wait().block_on(); } #[cfg_attr( diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs b/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs index 46c973af3..c035d8502 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs @@ -14,9 +14,12 @@ pub(super) struct InstanceData { impl InstanceData { #[cfg_attr(feature = "tracing", tracing::instrument(level = "Debug"))] - pub fn new(id: impl AsRef + std::fmt::Debug, config: InstanceConfig) -> Result { + pub async fn new( + id: impl AsRef + std::fmt::Debug, + config: InstanceConfig, + ) -> Result { let id = id.as_ref().to_string(); - let instance = T::new(id, &config)?; + let instance = T::new(id, &config).await?; Ok(Self { instance, config, @@ -31,11 +34,11 @@ impl InstanceData { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub fn start(&self) -> Result { + pub async fn start(&self) -> Result { let mut s = self.state.write().unwrap(); s.start()?; - let res = self.instance.start(); + let res = self.instance.start().await; // These state transitions are always `Ok(())` because // we hold the lock since `s.start()` @@ -51,19 +54,19 @@ impl InstanceData { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub fn kill(&self, signal: u32) -> Result<()> { + pub async fn kill(&self, signal: u32) -> Result<()> { let mut s = self.state.write().unwrap(); s.kill()?; - self.instance.kill(signal) + self.instance.kill(signal).await } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub fn delete(&self) -> Result<()> { + pub async fn delete(&self) -> Result<()> { let mut s = self.state.write().unwrap(); s.delete()?; - let res = self.instance.delete(); + let res = self.instance.delete().await; if res.is_err() { // Always `Ok(())` because we hold the lock since `s.delete()` diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/local.rs b/crates/containerd-shim-wasm/src/sandbox/shim/local.rs index 0682ea3d1..e121fc7d6 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/local.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/local.rs @@ -2,10 +2,7 @@ use std::collections::HashMap; use std::fs::create_dir_all; use std::ops::Not; use std::path::Path; -use std::sync::{Arc, RwLock}; -use std::thread; -#[cfg(feature = "opentelemetry")] -use std::time::Duration; +use std::sync::Arc; use anyhow::ensure; use containerd_shim::api::{ @@ -18,13 +15,14 @@ use containerd_shim::protos::events::task::{TaskCreate, TaskDelete, TaskExit, Ta use containerd_shim::protos::shim::shim_ttrpc::Task; use containerd_shim::protos::types::task::Status; use containerd_shim::util::IntoOption; -use containerd_shim::{DeleteResponse, ExitSignal, TtrpcContext, TtrpcResult}; +use containerd_shim::{DeleteResponse, TtrpcContext, TtrpcResult}; use futures::FutureExt as _; use log::debug; use oci_spec::runtime::Spec; use prost::Message; use protobuf::well_known_types::any::Any; use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; #[cfg(feature = "opentelemetry")] use tracing_opentelemetry::OpenTelemetrySpanExt as _; @@ -34,6 +32,7 @@ use crate::sandbox::async_utils::AmbientRuntime as _; use crate::sandbox::instance::{Instance, InstanceConfig}; use crate::sandbox::shim::events::{EventSender, RemoteEventSender, ToTimestamp}; use crate::sandbox::shim::instance_data::InstanceData; +use crate::sandbox::sync::WaitableCell; use crate::sandbox::{Error, Result, oci}; use crate::sys::metrics::get_metrics; @@ -90,7 +89,7 @@ pub struct Local pub engine: T::Engine, pub(super) instances: LocalInstances, events: E, - exit: Arc, + exit: WaitableCell<()>, namespace: String, containerd_address: String, } @@ -104,7 +103,7 @@ impl Local { pub fn new( engine: T::Engine, events: E, - exit: Arc, + exit: WaitableCell<()>, namespace: impl AsRef + std::fmt::Debug, containerd_address: impl AsRef + std::fmt::Debug, ) -> Self { @@ -122,26 +121,26 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub(super) fn get_instance(&self, id: &str) -> Result>> { - let instance = self.instances.read().unwrap().get(id).cloned(); + pub(super) async fn get_instance(&self, id: &str) -> Result>> { + let instance = self.instances.read().await.get(id).cloned(); instance.ok_or_else(|| Error::NotFound(id.to_string())) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn has_instance(&self, id: &str) -> bool { - self.instances.read().unwrap().contains_key(id) + async fn has_instance(&self, id: &str) -> bool { + self.instances.read().await.contains_key(id) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn is_empty(&self) -> bool { - self.instances.read().unwrap().is_empty() + async fn is_empty(&self) -> bool { + self.instances.read().await.is_empty() } } // These are the same functions as in Task, but without the TtrcpContext, which is useful for testing impl Local { #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_create(&self, req: CreateTaskRequest) -> Result { + async fn task_create(&self, req: CreateTaskRequest) -> Result { let config = Config::get_from_options(req.options.as_ref()) .map_err(|err| Error::InvalidArgument(format!("invalid shim options: {err}")))?; @@ -155,7 +154,7 @@ impl Local { )); } - if self.has_instance(&req.id) { + if self.has_instance(&req.id).await { return Err(Error::AlreadyExists(req.id)); } @@ -200,11 +199,11 @@ impl Local { }; // Check if this is a cri container - let instance = InstanceData::new(req.id(), cfg)?; + let instance = InstanceData::new(req.id(), cfg).await?; self.instances .write() - .unwrap() + .await .insert(req.id().to_string(), Arc::new(instance)); self.events.send(TaskCreate { @@ -234,13 +233,13 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_start(&self, req: StartRequest) -> Result { + async fn task_start(&self, req: StartRequest) -> Result { if req.exec_id().is_empty().not() { return Err(ShimError::Unimplemented("exec is not supported".to_string()).into()); } - let i = self.get_instance(req.id())?; - let pid = i.start()?; + let i = self.get_instance(req.id()).await?; + let pid = i.start().await?; self.events.send(TaskStart { container_id: req.id().into(), @@ -274,29 +273,29 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_kill(&self, req: KillRequest) -> Result { + async fn task_kill(&self, req: KillRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - self.get_instance(req.id())?.kill(req.signal())?; + self.get_instance(req.id()).await?.kill(req.signal()).await?; Ok(Empty::new()) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_delete(&self, req: DeleteRequest) -> Result { + async fn task_delete(&self, req: DeleteRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - let i = self.get_instance(req.id())?; + let i = self.get_instance(req.id()).await?; - i.delete()?; + i.delete().await?; let pid = i.pid().unwrap_or_default(); let (exit_code, timestamp) = i.wait().now_or_never().unzip(); let timestamp = timestamp.map(ToTimestamp::to_timestamp); - self.instances.write().unwrap().remove(req.id()); + self.instances.write().await.remove(req.id()); self.events.send(TaskDelete { container_id: req.id().into(), @@ -315,13 +314,13 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_wait(&self, req: WaitRequest) -> Result { + async fn task_wait(&self, req: WaitRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - let i = self.get_instance(req.id())?; - let (exit_code, timestamp) = i.wait().block_on(); + let i = self.get_instance(req.id()).await?; + let (exit_code, timestamp) = i.wait().await; debug!("wait finishes"); Ok(WaitResponse { @@ -332,12 +331,12 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_state(&self, req: StateRequest) -> Result { + async fn task_state(&self, req: StateRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - let i = self.get_instance(req.id())?; + let i = self.get_instance(req.id()).await?; let pid = i.pid(); let (exit_code, timestamp) = i.wait().now_or_never().unzip(); let timestamp = timestamp.map(ToTimestamp::to_timestamp); @@ -364,8 +363,8 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_stats(&self, req: StatsRequest) -> Result { - let i = self.get_instance(req.id())?; + async fn task_stats(&self, req: StatsRequest) -> Result { + let i = self.get_instance(req.id()).await?; let pid = i .pid() .ok_or_else(|| Error::InvalidArgument("task is not running".to_string()))?; @@ -391,7 +390,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_create(req)?) + Ok(self.task_create(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -401,7 +400,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_start(req)?) + Ok(self.task_start(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -411,7 +410,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_kill(req)?) + Ok(self.task_kill(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -421,7 +420,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_delete(req)?) + Ok(self.task_delete(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -429,33 +428,33 @@ impl Task for Local { debug!("wait: {:?}", req); #[cfg(feature = "opentelemetry")] - { + let span_exporter = { use tracing::{Level, Span, span}; let parent_span = Span::current(); parent_span.set_parent(extract_context(&_ctx.metadata)); - let (tx, rx) = std::sync::mpsc::channel(); - // Start a thread to export interval span for long wait - - let _ = thread::spawn(move || { + async move { loop { let current_span = span!(parent: &parent_span, Level::INFO, "task wait 60s interval"); let _enter = current_span.enter(); - if rx.recv_timeout(Duration::from_secs(60)).is_ok() { - break; - } + tokio::time::sleep(std::time::Duration::from_secs(60)).await; } - }); - let result = self.task_wait(req)?; - tx.send(()).unwrap(); - Ok(result) - } + } + }; #[cfg(not(feature = "opentelemetry"))] - { - Ok(self.task_wait(req)?) + let span_exporter = std::future::pending::<()>(); + + let res = async { + tokio::select! { + _ = span_exporter => unreachable!(), + res = self.task_wait(req) => res, + } } + .block_on()?; + + Ok(res) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -465,7 +464,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - let i = self.get_instance(req.id())?; + let i = self.get_instance(req.id()).block_on()?; let shim_pid = std::process::id(); let task_pid = i.pid().unwrap_or_default(); Ok(ConnectResponse { @@ -482,7 +481,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_state(req)?) + Ok(self.task_state(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -492,8 +491,8 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - if self.is_empty() { - self.exit.signal(); + if self.is_empty().block_on() { + let _ = self.exit.set(()); } Ok(Empty::new()) } @@ -505,6 +504,6 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_stats(req)?) + Ok(self.task_stats(req).block_on()?) } } diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs b/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs index 99ed5e3c2..f3a898c53 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs @@ -1,6 +1,4 @@ use std::fs::{File, create_dir}; -use std::sync::mpsc::{Sender, channel}; -use std::thread; use std::time::Duration; use anyhow::Context; @@ -10,6 +8,8 @@ use containerd_shim::event::Event; use protobuf::{MessageDyn, SpecialFields}; use serde_json as json; use tempfile::tempdir; +use tokio::sync::mpsc::{UnboundedSender as Sender, unbounded_channel as channel}; +use tokio_async_drop::tokio_async_drop; use super::*; use crate::sandbox::shim::events::EventSender; @@ -24,19 +24,19 @@ pub struct InstanceStub { impl Instance for InstanceStub { type Engine = (); - fn new(_id: String, _cfg: &InstanceConfig) -> Result { + async fn new(_id: String, _cfg: &InstanceConfig) -> Result { Ok(InstanceStub { exit_code: WaitableCell::new(), }) } - fn start(&self) -> Result { + async fn start(&self) -> Result { Ok(std::process::id()) } - fn kill(&self, _signal: u32) -> Result<(), Error> { + async fn kill(&self, _signal: u32) -> Result<(), Error> { let _ = self.exit_code.set((1, Utc::now())); Ok(()) } - fn delete(&self) -> Result<(), Error> { + async fn delete(&self) -> Result<(), Error> { Ok(()) } async fn wait(&self) -> (u32, DateTime) { @@ -62,15 +62,13 @@ impl EventSender for Sender<(String, Box)> { impl Drop for LocalWithDestructor { fn drop(&mut self) { - self.local - .instances - .write() - .unwrap() - .iter() - .for_each(|(_, v)| { - let _ = v.kill(9); - v.delete().unwrap(); - }); + tokio_async_drop!({ + let instances = self.local.instances.write().await; + for (_, instance) in instances.iter() { + let _ = instance.kill(9).await; + let _ = instance.delete().await; + } + }) } } @@ -98,8 +96,8 @@ fn create_bundle(dir: &std::path::Path, spec: Option) -> Result<()> { Ok(()) } -#[test] -fn test_delete_after_create() { +#[tokio::test] +async fn test_delete_after_create() -> anyhow::Result<()> { let dir = tempdir().unwrap(); let id = "test-delete-after-create"; create_bundle(dir.path(), None).unwrap(); @@ -108,7 +106,7 @@ fn test_delete_after_create() { let local = Arc::new(Local::::new( (), tx, - Arc::new(ExitSignal::default()), + WaitableCell::new(), "test_namespace", "/test/address", )); @@ -120,22 +118,24 @@ fn test_delete_after_create() { bundle: dir.path().to_str().unwrap().to_string(), ..Default::default() }) - .unwrap(); + .await?; local .task_delete(DeleteRequest { id: id.to_string(), ..Default::default() }) - .unwrap(); + .await?; + + Ok(()) } -#[test] -fn test_cri_task() -> Result<()> { +#[tokio::test] +async fn test_cri_task() -> Result<()> { // Currently the relationship between the "base" container and the "instances" are pretty weak. // When a cri sandbox is specified we just assume it's the sandbox container and treat it as such by not actually running the code (which is going to be wasm). let (etx, _erx) = channel(); - let exit_signal = Arc::new(ExitSignal::default()); + let exit_signal = WaitableCell::new(); let local = Arc::new(Local::::new( (), etx, @@ -151,39 +151,49 @@ fn test_cri_task() -> Result<()> { let sandbox_id = "test-cri-task".to_string(); create_bundle(dir, Some(with_cri_sandbox(None, sandbox_id.clone())))?; - local.task_create(CreateTaskRequest { - id: "testbase".to_string(), - bundle: dir.to_str().unwrap().to_string(), - ..Default::default() - })?; + local + .task_create(CreateTaskRequest { + id: "testbase".to_string(), + bundle: dir.to_str().unwrap().to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::CREATED); // make sure that the instance exists - let _i = local.get_instance("testbase")?; + let _i = local.get_instance("testbase").await?; - local.task_start(StartRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + local + .task_start(StartRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); let ll = local.clone(); - let (base_tx, base_rx) = channel(); - thread::spawn(move || { - let resp = ll.task_wait(WaitRequest { - id: "testbase".to_string(), - ..Default::default() - }); + let (base_tx, mut base_rx) = channel(); + tokio::spawn(async move { + let resp = ll + .task_wait(WaitRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await; base_tx.send(resp).unwrap(); }); base_rx.try_recv().unwrap_err(); @@ -192,72 +202,96 @@ fn test_cri_task() -> Result<()> { let dir2 = temp2.path(); create_bundle(dir2, Some(with_cri_sandbox(None, sandbox_id)))?; - local.task_create(CreateTaskRequest { - id: "testinstance".to_string(), - bundle: dir2.to_str().unwrap().to_string(), - ..Default::default() - })?; + local + .task_create(CreateTaskRequest { + id: "testinstance".to_string(), + bundle: dir2.to_str().unwrap().to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::CREATED); // make sure that the instance exists - let _i = local.get_instance("testinstance")?; + let _i = local.get_instance("testinstance").await?; - local.task_start(StartRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + local + .task_start(StartRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); - let stats = local.task_stats(StatsRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let stats = local + .task_stats(StatsRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert!(stats.has_stats()); let ll = local.clone(); - let (instance_tx, instance_rx) = channel(); - std::thread::spawn(move || { - let resp = ll.task_wait(WaitRequest { - id: "testinstance".to_string(), - ..Default::default() - }); + let (instance_tx, mut instance_rx) = channel(); + tokio::spawn(async move { + let resp = ll + .task_wait(WaitRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await; instance_tx.send(resp).unwrap(); }); instance_rx.try_recv().unwrap_err(); - local.task_kill(KillRequest { - id: "testinstance".to_string(), - signal: 9, - ..Default::default() - })?; + local + .task_kill(KillRequest { + id: "testinstance".to_string(), + signal: 9, + ..Default::default() + }) + .await?; - instance_rx.recv_timeout(Duration::from_secs(50)).unwrap()?; + instance_rx + .recv() + .with_timeout(Duration::from_secs(50)) + .await + .flatten() + .unwrap()?; - let state = local.task_state(StateRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::STOPPED); - local.task_delete(DeleteRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + local + .task_delete(DeleteRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; match local .task_state(StateRequest { id: "testinstance".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} @@ -265,34 +299,48 @@ fn test_cri_task() -> Result<()> { } base_rx.try_recv().unwrap_err(); - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); - local.task_kill(KillRequest { - id: "testbase".to_string(), - signal: 9, - ..Default::default() - })?; - - base_rx.recv_timeout(Duration::from_secs(5)).unwrap()?; - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + local + .task_kill(KillRequest { + id: "testbase".to_string(), + signal: 9, + ..Default::default() + }) + .await?; + + base_rx + .recv() + .with_timeout(Duration::from_secs(5)) + .await + .flatten() + .unwrap()?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::STOPPED); - local.task_delete(DeleteRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + local + .task_delete(DeleteRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; match local .task_state(StateRequest { id: "testbase".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} @@ -302,10 +350,10 @@ fn test_cri_task() -> Result<()> { Ok(()) } -#[test] -fn test_task_lifecycle() -> Result<()> { +#[tokio::test] +async fn test_task_lifecycle() -> Result<()> { let (etx, _erx) = channel(); // TODO: check events - let exit_signal = Arc::new(ExitSignal::default()); + let exit_signal = WaitableCell::new(); let local = Arc::new(Local::::new( (), etx, @@ -325,17 +373,20 @@ fn test_task_lifecycle() -> Result<()> { id: "test".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} e => return Err(e), } - local.task_create(CreateTaskRequest { - id: "test".to_string(), - bundle: dir.to_str().unwrap().to_string(), - ..Default::default() - })?; + local + .task_create(CreateTaskRequest { + id: "test".to_string(), + bundle: dir.to_str().unwrap().to_string(), + ..Default::default() + }) + .await?; match local .task_create(CreateTaskRequest { @@ -343,73 +394,95 @@ fn test_task_lifecycle() -> Result<()> { bundle: dir.to_str().unwrap().to_string(), ..Default::default() }) + .await .unwrap_err() { Error::AlreadyExists(_) => {} e => return Err(e), } - let state = local.task_state(StateRequest { - id: "test".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::CREATED); - local.task_start(StartRequest { - id: "test".to_string(), - ..Default::default() - })?; + local + .task_start(StartRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "test".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); - let (tx, rx) = channel(); + let (tx, mut rx) = channel(); let ll = local.clone(); - thread::spawn(move || { - let resp = ll.task_wait(WaitRequest { - id: "test".to_string(), - ..Default::default() - }); + tokio::spawn(async move { + let resp = ll + .task_wait(WaitRequest { + id: "test".to_string(), + ..Default::default() + }) + .await; tx.send(resp).unwrap(); }); rx.try_recv().unwrap_err(); - let res = local.task_stats(StatsRequest { - id: "test".to_string(), - ..Default::default() - })?; + let res = local + .task_stats(StatsRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert!(res.has_stats()); - local.task_kill(KillRequest { - id: "test".to_string(), - signal: 9, - ..Default::default() - })?; + local + .task_kill(KillRequest { + id: "test".to_string(), + signal: 9, + ..Default::default() + }) + .await?; - rx.recv_timeout(Duration::from_secs(5)).unwrap()?; + rx.recv() + .with_timeout(Duration::from_secs(5)) + .await + .flatten() + .unwrap()?; - let state = local.task_state(StateRequest { - id: "test".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::STOPPED); - local.task_delete(DeleteRequest { - id: "test".to_string(), - ..Default::default() - })?; + local + .task_delete(DeleteRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; match local .task_state(StateRequest { id: "test".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} diff --git a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs index 47f64f757..43284c75b 100644 --- a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs +++ b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs @@ -32,11 +32,11 @@ impl SandboxInstance for Instance { type Engine = E; #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn new(id: String, cfg: &InstanceConfig) -> Result { + async fn new(id: String, cfg: &InstanceConfig) -> Result { // check if container is OCI image with wasm layers and attempt to read the module - let (modules, platform) = containerd::Client::connect(&cfg.containerd_address, &cfg.namespace).block_on()? + let (modules, platform) = containerd::Client::connect(&cfg.containerd_address, &cfg.namespace).await? .load_modules(&id, &E::default()) - .block_on() + .await .unwrap_or_else(|e| { log::warn!("Error obtaining wasm layers for container {id}. Will attempt to use files inside container image. Error: {e}"); (vec![], Platform::default()) @@ -85,7 +85,7 @@ impl SandboxInstance for Instance { /// The returned value should be a unique ID (such as a PID) for the instance. /// Nothing internally should be using this ID, but it is returned to containerd where a user may want to use it. #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn start(&self) -> Result { + async fn start(&self) -> Result { log::info!("starting instance: {}", self.id); // make sure we have an exit code by the time we finish (even if there's a panic) let guard = self.exit_code.clone().set_guard_with(|| (137, Utc::now())); @@ -125,7 +125,7 @@ impl SandboxInstance for Instance { /// Send a signal to the instance #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn kill(&self, signal: u32) -> Result<(), SandboxError> { + async fn kill(&self, signal: u32) -> Result<(), SandboxError> { log::info!("sending signal {signal} to instance: {}", self.id); self.container.kill(signal)?; Ok(()) @@ -134,7 +134,7 @@ impl SandboxInstance for Instance { /// Delete any reference to the instance /// This is called after the instance has exited. #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn delete(&self) -> Result<(), SandboxError> { + async fn delete(&self) -> Result<(), SandboxError> { log::info!("deleting instance: {}", self.id); self.container.delete()?; Ok(()) diff --git a/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs b/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs index 486009923..54efcd965 100644 --- a/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs +++ b/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs @@ -11,25 +11,25 @@ pub struct Instance(PhantomData); impl SandboxInstance for Instance { type Engine = E; - fn new(_id: String, _cfg: &InstanceConfig) -> Result { + async fn new(_id: String, _cfg: &InstanceConfig) -> Result { todo!(); } /// Start the instance /// The returned value should be a unique ID (such as a PID) for the instance. /// Nothing internally should be using this ID, but it is returned to containerd where a user may want to use it. - fn start(&self) -> Result { + async fn start(&self) -> Result { todo!(); } /// Send a signal to the instance - fn kill(&self, _signal: u32) -> Result<(), SandboxError> { + async fn kill(&self, _signal: u32) -> Result<(), SandboxError> { todo!(); } /// Delete any reference to the instance /// This is called after the instance has exited. - fn delete(&self) -> Result<(), SandboxError> { + async fn delete(&self) -> Result<(), SandboxError> { todo!(); } diff --git a/crates/containerd-shim-wasm/src/testing.rs b/crates/containerd-shim-wasm/src/testing.rs index 00fa93b77..d2c5a659b 100644 --- a/crates/containerd-shim-wasm/src/testing.rs +++ b/crates/containerd-shim-wasm/src/testing.rs @@ -225,7 +225,7 @@ where ..Default::default() }; - let instance = WasiInstance::new(self.container_name, &cfg)?; + let instance = WasiInstance::new(self.container_name, &cfg).block_on()?; Ok(WasiTest { instance, tempdir }) } } @@ -244,7 +244,7 @@ where pub fn start(&self) -> Result<&Self> { log::info!("starting wasi test"); - let pid = self.instance.start()?; + let pid = self.instance.start().block_on()?; log::info!("wasi test pid {pid}"); Ok(self) @@ -252,25 +252,25 @@ where pub fn delete(&self) -> Result<&Self> { log::info!("deleting wasi test"); - self.instance.delete()?; + self.instance.delete().block_on()?; Ok(self) } pub fn ctrl_c(&self) -> Result<&Self> { log::info!("sending SIGINT"); - self.instance.kill(SIGINT as u32)?; + self.instance.kill(SIGINT as u32).block_on()?; Ok(self) } pub fn terminate(&self) -> Result<&Self> { log::info!("sending SIGTERM"); - self.instance.kill(SIGTERM as u32)?; + self.instance.kill(SIGTERM as u32).block_on()?; Ok(self) } pub fn kill(&self) -> Result<&Self> { log::info!("sending SIGKILL"); - self.instance.kill(SIGKILL as u32)?; + self.instance.kill(SIGKILL as u32).block_on()?; Ok(self) } @@ -279,7 +279,7 @@ where let (status, _) = match self.instance.wait().with_timeout(t).block_on() { Some(res) => res, None => { - self.instance.kill(SIGKILL)?; + self.instance.kill(SIGKILL).block_on()?; bail!("timeout while waiting for module to finish"); } }; @@ -287,7 +287,7 @@ where let stdout = self.read_stdout()?.unwrap_or_default(); let stderr = self.read_stderr()?.unwrap_or_default(); - self.instance.delete()?; + self.instance.delete().block_on()?; log::info!("wasi test status is {status}");