diff --git a/Cargo.lock b/Cargo.lock index 03748cc86..dfdb35805 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -936,11 +936,13 @@ dependencies = [ "tempfile", "thiserror 2.0.11", "tokio", + "tokio-async-drop", "tokio-stream", "toml", "tracing", "tracing-opentelemetry", "tracing-subscriber", + "trait-variant", "ttrpc-codegen", "wasmparser 0.226.0", "wat", diff --git a/crates/containerd-shim-wasm/Cargo.toml b/crates/containerd-shim-wasm/Cargo.toml index 61e191639..364823ae8 100644 --- a/crates/containerd-shim-wasm/Cargo.toml +++ b/crates/containerd-shim-wasm/Cargo.toml @@ -28,7 +28,7 @@ serde_json = { workspace = true } tempfile = { workspace = true, optional = true } thiserror = { workspace = true } wat = { workspace = true } -tokio = { workspace = true, features = ["rt-multi-thread"] } +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } futures = { version = "0.3.30" } wasmparser = { version = "0.226.0" } tokio-stream = { version = "0.1" } @@ -36,6 +36,8 @@ sha256 = { workspace = true } serde_bytes = "0.11" prost = "0.13" toml = "0.8" +trait-variant = "0.1" +tokio-async-drop = "0.1" # tracing # note: it's important to keep the version of tracing in sync with tracing-subscriber diff --git a/crates/containerd-shim-wasm/README.md b/crates/containerd-shim-wasm/README.md index a628576f3..7716f06e2 100644 --- a/crates/containerd-shim-wasm/README.md +++ b/crates/containerd-shim-wasm/README.md @@ -83,19 +83,19 @@ struct MyInstance { } impl Instance for MyInstance { - fn new(id: String, cfg: &InstanceConfig) -> Result { + async fn new(id: String, cfg: &InstanceConfig) -> Result { Ok(MyInstance { engine: MyEngine }) } - fn start(&self) -> Result { + async fn start(&self) -> Result { Ok(1) } - fn kill(&self, signal: u32) -> Result<(), Error> { + async fn kill(&self, signal: u32) -> Result<(), Error> { Ok(()) } - fn delete(&self) -> Result<(), Error> { + async fn delete(&self) -> Result<(), Error> { Ok(()) } diff --git a/crates/containerd-shim-wasm/src/sandbox/cli.rs b/crates/containerd-shim-wasm/src/sandbox/cli.rs index 1e9f12e9a..30dc5aafb 100644 --- a/crates/containerd-shim-wasm/src/sandbox/cli.rs +++ b/crates/containerd-shim-wasm/src/sandbox/cli.rs @@ -91,6 +91,7 @@ pub mod r#impl { pub use git_version::git_version; } +use super::async_utils::AmbientRuntime as _; pub use crate::{revision, version}; /// Get the crate version from Cargo.toml. @@ -171,15 +172,16 @@ pub fn shim_main<'a, I>( #[cfg(feature = "opentelemetry")] if otel_traces_enabled() { // opentelemetry uses tokio, so we need to initialize a runtime - use tokio::runtime::Runtime; - let rt = Runtime::new().unwrap(); - rt.block_on(async { + async { let otlp_config = OtlpConfig::build_from_env().expect("Failed to build OtelConfig."); let _guard = otlp_config .init() .expect("Failed to initialize OpenTelemetry."); - shim_main_inner::(name, version, revision, shim_version, config); - }); + tokio::task::block_in_place(move || { + shim_main_inner::(name, version, revision, shim_version, config); + }); + } + .block_on(); } else { shim_main_inner::(name, version, revision, shim_version, config); } diff --git a/crates/containerd-shim-wasm/src/sandbox/containerd/lease.rs b/crates/containerd-shim-wasm/src/sandbox/containerd/lease.rs index 589535264..6ca77aa9f 100644 --- a/crates/containerd-shim-wasm/src/sandbox/containerd/lease.rs +++ b/crates/containerd-shim-wasm/src/sandbox/containerd/lease.rs @@ -7,6 +7,7 @@ use containerd_client::services::v1::DeleteRequest; use containerd_client::services::v1::leases_client::LeasesClient; use containerd_client::tonic::transport::Channel; use containerd_client::{tonic, with_namespace}; +use tokio_async_drop::tokio_async_drop; use tonic::Request; // Adds lease info to grpc header @@ -74,7 +75,7 @@ impl LeaseGuardInner { impl Drop for LeaseGuard { fn drop(&mut self) { let inner = self.inner.take().unwrap(); - tokio::spawn(async move { + tokio_async_drop!({ match inner.release().await { Ok(()) => log::info!("removed lease"), Err(err) => log::warn!("error removing lease: {err}"), diff --git a/crates/containerd-shim-wasm/src/sandbox/instance.rs b/crates/containerd-shim-wasm/src/sandbox/instance.rs index c7a2eb1d8..a334179e7 100644 --- a/crates/containerd-shim-wasm/src/sandbox/instance.rs +++ b/crates/containerd-shim-wasm/src/sandbox/instance.rs @@ -32,25 +32,26 @@ pub struct InstanceConfig { /// Instance is a trait that gets implemented by consumers of this library. /// This trait requires that any type implementing it is `'static`, similar to `std::any::Any`. /// This means that the type cannot contain a non-`'static` reference. +#[trait_variant::make(Send)] pub trait Instance: 'static { /// Create a new instance - fn new(id: String, cfg: &InstanceConfig) -> Result + async fn new(id: String, cfg: &InstanceConfig) -> Result where Self: Sized; /// Start the instance /// The returned value should be a unique ID (such as a PID) for the instance. /// Nothing internally should be using this ID, but it is returned to containerd where a user may want to use it. - fn start(&self) -> Result; + async fn start(&self) -> Result; /// Send a signal to the instance - fn kill(&self, signal: u32) -> Result<(), Error>; + async fn kill(&self, signal: u32) -> Result<(), Error>; /// Delete any reference to the instance /// This is called after the instance has exited. - fn delete(&self) -> Result<(), Error>; + async fn delete(&self) -> Result<(), Error>; /// Waits for the instance to finish and returns its exit code /// This is an async call. - fn wait(&self) -> impl Future)> + Send; + async fn wait(&self) -> (u32, DateTime); } diff --git a/crates/containerd-shim-wasm/src/sandbox/mod.rs b/crates/containerd-shim-wasm/src/sandbox/mod.rs index 935758c9b..23778bbd2 100644 --- a/crates/containerd-shim-wasm/src/sandbox/mod.rs +++ b/crates/containerd-shim-wasm/src/sandbox/mod.rs @@ -46,19 +46,19 @@ //! } //! //! impl Instance for MyInstance { -//! fn new(id: String, cfg: &InstanceConfig) -> Result { +//! async fn new(id: String, cfg: &InstanceConfig) -> Result { //! Ok(MyInstance { engine: MyEngine }) //! } //! -//! fn start(&self) -> Result { +//! async fn start(&self) -> Result { //! Ok(1) //! } //! -//! fn kill(&self, signal: u32) -> Result<(), Error> { +//! async fn kill(&self, signal: u32) -> Result<(), Error> { //! Ok(()) //! } //! -//! fn delete(&self) -> Result<(), Error> { +//! async fn delete(&self) -> Result<(), Error> { //! Ok(()) //! } //! diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs b/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs index 9a1f62f2e..9bb0c1a5f 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs @@ -1,25 +1,26 @@ use std::env::current_dir; use std::fmt::Debug; use std::marker::PhantomData; -use std::sync::Arc; use chrono::Utc; use containerd_shim::error::Error as ShimError; use containerd_shim::publisher::RemotePublisher; use containerd_shim::util::write_address; -use containerd_shim::{self as shim, ExitSignal, api}; +use containerd_shim::{self as shim, api}; use oci_spec::runtime::Spec; use shim::Flags; +use crate::sandbox::async_utils::AmbientRuntime as _; use crate::sandbox::instance::Instance; use crate::sandbox::shim::events::{RemoteEventSender, ToTimestamp}; use crate::sandbox::shim::local::Local; +use crate::sandbox::sync::WaitableCell; /// Cli implements the containerd-shim cli interface using `Local` as the task service. pub struct Cli { namespace: String, containerd_address: String, - exit: Arc, + exit: WaitableCell<()>, _id: String, _phantom: PhantomData, } @@ -48,7 +49,7 @@ where Cli { namespace: args.namespace.to_string(), containerd_address: args.address.clone(), - exit: Arc::default(), + exit: WaitableCell::new(), _id: args.id.to_string(), _phantom: PhantomData, } @@ -77,7 +78,7 @@ where #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] fn wait(&mut self) { - self.exit.wait(); + self.exit.wait().block_on(); } #[cfg_attr( diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs b/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs index 46c973af3..db0f1e80f 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs @@ -1,6 +1,5 @@ -use std::sync::{OnceLock, RwLock}; - use chrono::{DateTime, Utc}; +use tokio::sync::{OnceCell, RwLock}; use crate::sandbox::shim::task_state::TaskState; use crate::sandbox::{Instance, InstanceConfig, Result}; @@ -8,19 +7,22 @@ use crate::sandbox::{Instance, InstanceConfig, Result}; pub(super) struct InstanceData { pub instance: T, pub config: InstanceConfig, - pid: OnceLock, + pid: OnceCell, state: RwLock, } impl InstanceData { #[cfg_attr(feature = "tracing", tracing::instrument(level = "Debug"))] - pub fn new(id: impl AsRef + std::fmt::Debug, config: InstanceConfig) -> Result { + pub async fn new( + id: impl AsRef + std::fmt::Debug, + config: InstanceConfig, + ) -> Result { let id = id.as_ref().to_string(); - let instance = T::new(id, &config)?; + let instance = T::new(id, &config).await?; Ok(Self { instance, config, - pid: OnceLock::default(), + pid: OnceCell::default(), state: RwLock::new(TaskState::Created), }) } @@ -31,11 +33,11 @@ impl InstanceData { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub fn start(&self) -> Result { - let mut s = self.state.write().unwrap(); + pub async fn start(&self) -> Result { + let mut s = self.state.write().await; s.start()?; - let res = self.instance.start(); + let res = self.instance.start().await; // These state transitions are always `Ok(())` because // we hold the lock since `s.start()` @@ -51,19 +53,19 @@ impl InstanceData { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub fn kill(&self, signal: u32) -> Result<()> { - let mut s = self.state.write().unwrap(); + pub async fn kill(&self, signal: u32) -> Result<()> { + let mut s = self.state.write().await; s.kill()?; - self.instance.kill(signal) + self.instance.kill(signal).await } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub fn delete(&self) -> Result<()> { - let mut s = self.state.write().unwrap(); + pub async fn delete(&self) -> Result<()> { + let mut s = self.state.write().await; s.delete()?; - let res = self.instance.delete(); + let res = self.instance.delete().await; if res.is_err() { // Always `Ok(())` because we hold the lock since `s.delete()` @@ -76,7 +78,7 @@ impl InstanceData { #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] pub async fn wait(&self) -> (u32, DateTime) { let res = self.instance.wait().await; - let mut s = self.state.write().unwrap(); + let mut s = self.state.write().await; *s = TaskState::Exited; res } diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/local.rs b/crates/containerd-shim-wasm/src/sandbox/shim/local.rs index 866810877..f999f4084 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/local.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/local.rs @@ -2,10 +2,7 @@ use std::collections::HashMap; use std::fs::create_dir_all; use std::ops::Not; use std::path::Path; -use std::sync::{Arc, RwLock}; -use std::thread; -#[cfg(feature = "opentelemetry")] -use std::time::Duration; +use std::sync::Arc; use anyhow::ensure; use containerd_shim::api::{ @@ -18,13 +15,14 @@ use containerd_shim::protos::events::task::{TaskCreate, TaskDelete, TaskExit, Ta use containerd_shim::protos::shim::shim_ttrpc::Task; use containerd_shim::protos::types::task::Status; use containerd_shim::util::IntoOption; -use containerd_shim::{DeleteResponse, ExitSignal, TtrpcContext, TtrpcResult}; +use containerd_shim::{DeleteResponse, TtrpcContext, TtrpcResult}; use futures::FutureExt as _; use log::debug; use oci_spec::runtime::Spec; use prost::Message; use protobuf::well_known_types::any::Any; use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; #[cfg(feature = "opentelemetry")] use tracing_opentelemetry::OpenTelemetrySpanExt as _; @@ -34,6 +32,7 @@ use crate::sandbox::async_utils::AmbientRuntime as _; use crate::sandbox::instance::{Instance, InstanceConfig}; use crate::sandbox::shim::events::{EventSender, RemoteEventSender, ToTimestamp}; use crate::sandbox::shim::instance_data::InstanceData; +use crate::sandbox::sync::WaitableCell; use crate::sandbox::{Error, Result, oci}; use crate::sys::metrics::get_metrics; @@ -89,7 +88,7 @@ type LocalInstances = RwLock>>>; pub struct Local { pub(super) instances: LocalInstances, events: E, - exit: Arc, + exit: WaitableCell<()>, namespace: String, containerd_address: String, } @@ -102,7 +101,7 @@ impl Local { )] pub fn new( events: E, - exit: Arc, + exit: WaitableCell<()>, namespace: impl AsRef + std::fmt::Debug, containerd_address: impl AsRef + std::fmt::Debug, ) -> Self { @@ -119,26 +118,26 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub(super) fn get_instance(&self, id: &str) -> Result>> { - let instance = self.instances.read().unwrap().get(id).cloned(); + pub(super) async fn get_instance(&self, id: &str) -> Result>> { + let instance = self.instances.read().await.get(id).cloned(); instance.ok_or_else(|| Error::NotFound(id.to_string())) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn has_instance(&self, id: &str) -> bool { - self.instances.read().unwrap().contains_key(id) + async fn has_instance(&self, id: &str) -> bool { + self.instances.read().await.contains_key(id) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn is_empty(&self) -> bool { - self.instances.read().unwrap().is_empty() + async fn is_empty(&self) -> bool { + self.instances.read().await.is_empty() } } // These are the same functions as in Task, but without the TtrcpContext, which is useful for testing impl Local { #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_create(&self, req: CreateTaskRequest) -> Result { + async fn task_create(&self, req: CreateTaskRequest) -> Result { let config = Config::get_from_options(req.options.as_ref()) .map_err(|err| Error::InvalidArgument(format!("invalid shim options: {err}")))?; @@ -152,7 +151,7 @@ impl Local { )); } - if self.has_instance(&req.id) { + if self.has_instance(&req.id).await { return Err(Error::AlreadyExists(req.id)); } @@ -197,11 +196,11 @@ impl Local { }; // Check if this is a cri container - let instance = InstanceData::new(req.id(), cfg)?; + let instance = InstanceData::new(req.id(), cfg).await?; self.instances .write() - .unwrap() + .await .insert(req.id().to_string(), Arc::new(instance)); self.events.send(TaskCreate { @@ -231,13 +230,13 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_start(&self, req: StartRequest) -> Result { + async fn task_start(&self, req: StartRequest) -> Result { if req.exec_id().is_empty().not() { return Err(ShimError::Unimplemented("exec is not supported".to_string()).into()); } - let i = self.get_instance(req.id())?; - let pid = i.start()?; + let i = self.get_instance(req.id()).await?; + let pid = i.start().await?; self.events.send(TaskStart { container_id: req.id().into(), @@ -271,29 +270,32 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_kill(&self, req: KillRequest) -> Result { + async fn task_kill(&self, req: KillRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - self.get_instance(req.id())?.kill(req.signal())?; + self.get_instance(req.id()) + .await? + .kill(req.signal()) + .await?; Ok(Empty::new()) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_delete(&self, req: DeleteRequest) -> Result { + async fn task_delete(&self, req: DeleteRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - let i = self.get_instance(req.id())?; + let i = self.get_instance(req.id()).await?; - i.delete()?; + i.delete().await?; let pid = i.pid().unwrap_or_default(); let (exit_code, timestamp) = i.wait().now_or_never().unzip(); let timestamp = timestamp.map(ToTimestamp::to_timestamp); - self.instances.write().unwrap().remove(req.id()); + self.instances.write().await.remove(req.id()); self.events.send(TaskDelete { container_id: req.id().into(), @@ -312,13 +314,13 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_wait(&self, req: WaitRequest) -> Result { + async fn task_wait(&self, req: WaitRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - let i = self.get_instance(req.id())?; - let (exit_code, timestamp) = i.wait().block_on(); + let i = self.get_instance(req.id()).await?; + let (exit_code, timestamp) = i.wait().await; debug!("wait finishes"); Ok(WaitResponse { @@ -329,12 +331,12 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_state(&self, req: StateRequest) -> Result { + async fn task_state(&self, req: StateRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - let i = self.get_instance(req.id())?; + let i = self.get_instance(req.id()).await?; let pid = i.pid(); let (exit_code, timestamp) = i.wait().now_or_never().unzip(); let timestamp = timestamp.map(ToTimestamp::to_timestamp); @@ -361,8 +363,8 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_stats(&self, req: StatsRequest) -> Result { - let i = self.get_instance(req.id())?; + async fn task_stats(&self, req: StatsRequest) -> Result { + let i = self.get_instance(req.id()).await?; let pid = i .pid() .ok_or_else(|| Error::InvalidArgument("task is not running".to_string()))?; @@ -388,7 +390,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_create(req)?) + Ok(self.task_create(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -398,7 +400,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_start(req)?) + Ok(self.task_start(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -408,7 +410,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_kill(req)?) + Ok(self.task_kill(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -418,7 +420,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_delete(req)?) + Ok(self.task_delete(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -426,33 +428,39 @@ impl Task for Local { debug!("wait: {:?}", req); #[cfg(feature = "opentelemetry")] - { + let span_exporter = { use tracing::{Level, Span, span}; let parent_span = Span::current(); parent_span.set_parent(extract_context(&_ctx.metadata)); - let (tx, rx) = std::sync::mpsc::channel(); - // Start a thread to export interval span for long wait - - let _ = thread::spawn(move || { + // This future never completes as it runs an infinite loop. + // It will stop executing when dropped. + // We need to keep this future's lifetime tied to this + // method's lifetime. + // This means we shouldn't tokio::spawn it, but ruther + // tokio::select! it inside of this async method. + async move { loop { let current_span = span!(parent: &parent_span, Level::INFO, "task wait 60s interval"); let _enter = current_span.enter(); - if rx.recv_timeout(Duration::from_secs(60)).is_ok() { - break; - } + tokio::time::sleep(std::time::Duration::from_secs(60)).await; } - }); - let result = self.task_wait(req)?; - tx.send(()).unwrap(); - Ok(result) - } + } + }; #[cfg(not(feature = "opentelemetry"))] - { - Ok(self.task_wait(req)?) + let span_exporter = std::future::pending::<()>(); + + let res = async { + tokio::select! { + _ = span_exporter => unreachable!(), + res = self.task_wait(req) => res, + } } + .block_on()?; + + Ok(res) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -462,7 +470,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - let i = self.get_instance(req.id())?; + let i = self.get_instance(req.id()).block_on()?; let shim_pid = std::process::id(); let task_pid = i.pid().unwrap_or_default(); Ok(ConnectResponse { @@ -479,7 +487,7 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_state(req)?) + Ok(self.task_state(req).block_on()?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] @@ -489,8 +497,8 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - if self.is_empty() { - self.exit.signal(); + if self.is_empty().block_on() { + let _ = self.exit.set(()); } Ok(Empty::new()) } @@ -502,6 +510,6 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_stats(req)?) + Ok(self.task_stats(req).block_on()?) } } diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs b/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs index 9ff5d8562..e1f4bcca9 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs @@ -1,6 +1,4 @@ use std::fs::{File, create_dir}; -use std::sync::mpsc::{Sender, channel}; -use std::thread; use std::time::Duration; use anyhow::Context; @@ -10,6 +8,8 @@ use containerd_shim::event::Event; use protobuf::{MessageDyn, SpecialFields}; use serde_json as json; use tempfile::tempdir; +use tokio::sync::mpsc::{UnboundedSender as Sender, unbounded_channel as channel}; +use tokio_async_drop::tokio_async_drop; use super::*; use crate::sandbox::shim::events::EventSender; @@ -23,19 +23,19 @@ pub struct InstanceStub { } impl Instance for InstanceStub { - fn new(_id: String, _cfg: &InstanceConfig) -> Result { + async fn new(_id: String, _cfg: &InstanceConfig) -> Result { Ok(InstanceStub { exit_code: WaitableCell::new(), }) } - fn start(&self) -> Result { + async fn start(&self) -> Result { Ok(std::process::id()) } - fn kill(&self, _signal: u32) -> Result<(), Error> { + async fn kill(&self, _signal: u32) -> Result<(), Error> { let _ = self.exit_code.set((1, Utc::now())); Ok(()) } - fn delete(&self) -> Result<(), Error> { + async fn delete(&self) -> Result<(), Error> { Ok(()) } async fn wait(&self) -> (u32, DateTime) { @@ -61,15 +61,13 @@ impl EventSender for Sender<(String, Box)> { impl Drop for LocalWithDestructor { fn drop(&mut self) { - self.local - .instances - .write() - .unwrap() - .iter() - .for_each(|(_, v)| { - let _ = v.kill(9); - v.delete().unwrap(); - }); + tokio_async_drop!({ + let instances = self.local.instances.write().await; + for (_, instance) in instances.iter() { + let _ = instance.kill(9).await; + let _ = instance.delete().await; + } + }) } } @@ -97,8 +95,10 @@ fn create_bundle(dir: &std::path::Path, spec: Option) -> Result<()> { Ok(()) } -#[test] -fn test_delete_after_create() { +// Use a multi threaded runtime because LocalWithDestructor needs +// it to run its async drop. +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn test_delete_after_create() -> anyhow::Result<()> { let dir = tempdir().unwrap(); let id = "test-delete-after-create"; create_bundle(dir.path(), None).unwrap(); @@ -106,7 +106,7 @@ fn test_delete_after_create() { let (tx, _rx) = channel(); let local = Arc::new(Local::::new( tx, - Arc::new(ExitSignal::default()), + WaitableCell::new(), "test_namespace", "/test/address", )); @@ -118,22 +118,26 @@ fn test_delete_after_create() { bundle: dir.path().to_str().unwrap().to_string(), ..Default::default() }) - .unwrap(); + .await?; local .task_delete(DeleteRequest { id: id.to_string(), ..Default::default() }) - .unwrap(); + .await?; + + Ok(()) } -#[test] -fn test_cri_task() -> Result<()> { +// Use a multi threaded runtime because LocalWithDestructor needs +// it to run its async drop. +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn test_cri_task() -> Result<()> { // Currently the relationship between the "base" container and the "instances" are pretty weak. // When a cri sandbox is specified we just assume it's the sandbox container and treat it as such by not actually running the code (which is going to be wasm). let (etx, _erx) = channel(); - let exit_signal = Arc::new(ExitSignal::default()); + let exit_signal = WaitableCell::new(); let local = Arc::new(Local::::new( etx, exit_signal, @@ -148,39 +152,49 @@ fn test_cri_task() -> Result<()> { let sandbox_id = "test-cri-task".to_string(); create_bundle(dir, Some(with_cri_sandbox(None, sandbox_id.clone())))?; - local.task_create(CreateTaskRequest { - id: "testbase".to_string(), - bundle: dir.to_str().unwrap().to_string(), - ..Default::default() - })?; + local + .task_create(CreateTaskRequest { + id: "testbase".to_string(), + bundle: dir.to_str().unwrap().to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::CREATED); // make sure that the instance exists - let _i = local.get_instance("testbase")?; + let _i = local.get_instance("testbase").await?; - local.task_start(StartRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + local + .task_start(StartRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); let ll = local.clone(); - let (base_tx, base_rx) = channel(); - thread::spawn(move || { - let resp = ll.task_wait(WaitRequest { - id: "testbase".to_string(), - ..Default::default() - }); + let (base_tx, mut base_rx) = channel(); + tokio::spawn(async move { + let resp = ll + .task_wait(WaitRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await; base_tx.send(resp).unwrap(); }); base_rx.try_recv().unwrap_err(); @@ -189,72 +203,96 @@ fn test_cri_task() -> Result<()> { let dir2 = temp2.path(); create_bundle(dir2, Some(with_cri_sandbox(None, sandbox_id)))?; - local.task_create(CreateTaskRequest { - id: "testinstance".to_string(), - bundle: dir2.to_str().unwrap().to_string(), - ..Default::default() - })?; + local + .task_create(CreateTaskRequest { + id: "testinstance".to_string(), + bundle: dir2.to_str().unwrap().to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::CREATED); // make sure that the instance exists - let _i = local.get_instance("testinstance")?; + let _i = local.get_instance("testinstance").await?; - local.task_start(StartRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + local + .task_start(StartRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); - let stats = local.task_stats(StatsRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let stats = local + .task_stats(StatsRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert!(stats.has_stats()); let ll = local.clone(); - let (instance_tx, instance_rx) = channel(); - std::thread::spawn(move || { - let resp = ll.task_wait(WaitRequest { - id: "testinstance".to_string(), - ..Default::default() - }); + let (instance_tx, mut instance_rx) = channel(); + tokio::spawn(async move { + let resp = ll + .task_wait(WaitRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await; instance_tx.send(resp).unwrap(); }); instance_rx.try_recv().unwrap_err(); - local.task_kill(KillRequest { - id: "testinstance".to_string(), - signal: 9, - ..Default::default() - })?; + local + .task_kill(KillRequest { + id: "testinstance".to_string(), + signal: 9, + ..Default::default() + }) + .await?; - instance_rx.recv_timeout(Duration::from_secs(50)).unwrap()?; + instance_rx + .recv() + .with_timeout(Duration::from_secs(50)) + .await + .flatten() + .unwrap()?; - let state = local.task_state(StateRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::STOPPED); - local.task_delete(DeleteRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + local + .task_delete(DeleteRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; match local .task_state(StateRequest { id: "testinstance".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} @@ -262,34 +300,48 @@ fn test_cri_task() -> Result<()> { } base_rx.try_recv().unwrap_err(); - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); - local.task_kill(KillRequest { - id: "testbase".to_string(), - signal: 9, - ..Default::default() - })?; - - base_rx.recv_timeout(Duration::from_secs(5)).unwrap()?; - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + local + .task_kill(KillRequest { + id: "testbase".to_string(), + signal: 9, + ..Default::default() + }) + .await?; + + base_rx + .recv() + .with_timeout(Duration::from_secs(5)) + .await + .flatten() + .unwrap()?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::STOPPED); - local.task_delete(DeleteRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + local + .task_delete(DeleteRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; match local .task_state(StateRequest { id: "testbase".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} @@ -299,10 +351,12 @@ fn test_cri_task() -> Result<()> { Ok(()) } -#[test] -fn test_task_lifecycle() -> Result<()> { +// Use a multi threaded runtime because LocalWithDestructor needs +// it to run its async drop. +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn test_task_lifecycle() -> Result<()> { let (etx, _erx) = channel(); // TODO: check events - let exit_signal = Arc::new(ExitSignal::default()); + let exit_signal = WaitableCell::new(); let local = Arc::new(Local::::new( etx, exit_signal, @@ -321,17 +375,20 @@ fn test_task_lifecycle() -> Result<()> { id: "test".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} e => return Err(e), } - local.task_create(CreateTaskRequest { - id: "test".to_string(), - bundle: dir.to_str().unwrap().to_string(), - ..Default::default() - })?; + local + .task_create(CreateTaskRequest { + id: "test".to_string(), + bundle: dir.to_str().unwrap().to_string(), + ..Default::default() + }) + .await?; match local .task_create(CreateTaskRequest { @@ -339,73 +396,95 @@ fn test_task_lifecycle() -> Result<()> { bundle: dir.to_str().unwrap().to_string(), ..Default::default() }) + .await .unwrap_err() { Error::AlreadyExists(_) => {} e => return Err(e), } - let state = local.task_state(StateRequest { - id: "test".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::CREATED); - local.task_start(StartRequest { - id: "test".to_string(), - ..Default::default() - })?; + local + .task_start(StartRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "test".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); - let (tx, rx) = channel(); + let (tx, mut rx) = channel(); let ll = local.clone(); - thread::spawn(move || { - let resp = ll.task_wait(WaitRequest { - id: "test".to_string(), - ..Default::default() - }); + tokio::spawn(async move { + let resp = ll + .task_wait(WaitRequest { + id: "test".to_string(), + ..Default::default() + }) + .await; tx.send(resp).unwrap(); }); rx.try_recv().unwrap_err(); - let res = local.task_stats(StatsRequest { - id: "test".to_string(), - ..Default::default() - })?; + let res = local + .task_stats(StatsRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert!(res.has_stats()); - local.task_kill(KillRequest { - id: "test".to_string(), - signal: 9, - ..Default::default() - })?; + local + .task_kill(KillRequest { + id: "test".to_string(), + signal: 9, + ..Default::default() + }) + .await?; - rx.recv_timeout(Duration::from_secs(5)).unwrap()?; + rx.recv() + .with_timeout(Duration::from_secs(5)) + .await + .flatten() + .unwrap()?; - let state = local.task_state(StateRequest { - id: "test".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::STOPPED); - local.task_delete(DeleteRequest { - id: "test".to_string(), - ..Default::default() - })?; + local + .task_delete(DeleteRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; match local .task_state(StateRequest { id: "test".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} diff --git a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs index 656314863..bb6c26213 100644 --- a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs +++ b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs @@ -30,11 +30,11 @@ pub struct Instance { impl SandboxInstance for Instance { #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn new(id: String, cfg: &InstanceConfig) -> Result { + async fn new(id: String, cfg: &InstanceConfig) -> Result { // check if container is OCI image with wasm layers and attempt to read the module - let (modules, platform) = containerd::Client::connect(&cfg.containerd_address, &cfg.namespace).block_on()? + let (modules, platform) = containerd::Client::connect(&cfg.containerd_address, &cfg.namespace).await? .load_modules(&id, &E::default()) - .block_on() + .await .unwrap_or_else(|e| { log::warn!("Error obtaining wasm layers for container {id}. Will attempt to use files inside container image. Error: {e}"); (vec![], Platform::default()) @@ -83,7 +83,7 @@ impl SandboxInstance for Instance { /// The returned value should be a unique ID (such as a PID) for the instance. /// Nothing internally should be using this ID, but it is returned to containerd where a user may want to use it. #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn start(&self) -> Result { + async fn start(&self) -> Result { log::info!("starting instance: {}", self.id); // make sure we have an exit code by the time we finish (even if there's a panic) let guard = self.exit_code.clone().set_guard_with(|| (137, Utc::now())); @@ -123,7 +123,7 @@ impl SandboxInstance for Instance { /// Send a signal to the instance #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn kill(&self, signal: u32) -> Result<(), SandboxError> { + async fn kill(&self, signal: u32) -> Result<(), SandboxError> { log::info!("sending signal {signal} to instance: {}", self.id); self.container.kill(signal)?; Ok(()) @@ -132,7 +132,7 @@ impl SandboxInstance for Instance { /// Delete any reference to the instance /// This is called after the instance has exited. #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn delete(&self) -> Result<(), SandboxError> { + async fn delete(&self) -> Result<(), SandboxError> { log::info!("deleting instance: {}", self.id); self.container.delete()?; Ok(()) diff --git a/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs b/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs index dc7aa0eba..afdf4c5ab 100644 --- a/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs +++ b/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs @@ -9,25 +9,25 @@ use crate::sandbox::{Error as SandboxError, Instance as SandboxInstance, Instanc pub struct Instance(PhantomData); impl SandboxInstance for Instance { - fn new(_id: String, _cfg: &InstanceConfig) -> Result { + async fn new(_id: String, _cfg: &InstanceConfig) -> Result { todo!(); } /// Start the instance /// The returned value should be a unique ID (such as a PID) for the instance. /// Nothing internally should be using this ID, but it is returned to containerd where a user may want to use it. - fn start(&self) -> Result { + async fn start(&self) -> Result { todo!(); } /// Send a signal to the instance - fn kill(&self, _signal: u32) -> Result<(), SandboxError> { + async fn kill(&self, _signal: u32) -> Result<(), SandboxError> { todo!(); } /// Delete any reference to the instance /// This is called after the instance has exited. - fn delete(&self) -> Result<(), SandboxError> { + async fn delete(&self) -> Result<(), SandboxError> { todo!(); } diff --git a/crates/containerd-shim-wasm/src/testing.rs b/crates/containerd-shim-wasm/src/testing.rs index d8ffdc724..e179d9e4c 100644 --- a/crates/containerd-shim-wasm/src/testing.rs +++ b/crates/containerd-shim-wasm/src/testing.rs @@ -84,6 +84,7 @@ impl WasiTestBuilder { // Removing the `network` namespace results in the binding to the host's socket. // This allows for direct communication with the host's networking interface. self.namespaces + // typos:disable-next-line - false positive "typ" .retain(|ns| ns.typ() != LinuxNamespaceType::Network); self } @@ -216,7 +217,7 @@ impl WasiTestBuilder { ..Default::default() }; - let instance = WasiInstance::new(self.container_name, &cfg)?; + let instance = WasiInstance::new(self.container_name, &cfg).block_on()?; Ok(WasiTest { instance, tempdir }) } } @@ -232,7 +233,7 @@ impl WasiTest { pub fn start(&self) -> Result<&Self> { log::info!("starting wasi test"); - let pid = self.instance.start()?; + let pid = self.instance.start().block_on()?; log::info!("wasi test pid {pid}"); Ok(self) @@ -240,25 +241,25 @@ impl WasiTest { pub fn delete(&self) -> Result<&Self> { log::info!("deleting wasi test"); - self.instance.delete()?; + self.instance.delete().block_on()?; Ok(self) } pub fn ctrl_c(&self) -> Result<&Self> { log::info!("sending SIGINT"); - self.instance.kill(SIGINT as u32)?; + self.instance.kill(SIGINT as u32).block_on()?; Ok(self) } pub fn terminate(&self) -> Result<&Self> { log::info!("sending SIGTERM"); - self.instance.kill(SIGTERM as u32)?; + self.instance.kill(SIGTERM as u32).block_on()?; Ok(self) } pub fn kill(&self) -> Result<&Self> { log::info!("sending SIGKILL"); - self.instance.kill(SIGKILL as u32)?; + self.instance.kill(SIGKILL as u32).block_on()?; Ok(self) } @@ -267,7 +268,7 @@ impl WasiTest { let (status, _) = match self.instance.wait().with_timeout(t).block_on() { Some(res) => res, None => { - self.instance.kill(SIGKILL)?; + self.instance.kill(SIGKILL).block_on()?; bail!("timeout while waiting for module to finish"); } }; @@ -275,7 +276,7 @@ impl WasiTest { let stdout = self.read_stdout()?.unwrap_or_default(); let stderr = self.read_stderr()?.unwrap_or_default(); - self.instance.delete()?; + self.instance.delete().block_on()?; log::info!("wasi test status is {status}"); diff --git a/typos.toml b/typos.toml index 0a4f824fc..fbde07b4f 100644 --- a/typos.toml +++ b/typos.toml @@ -1,2 +1,11 @@ +[default] +# allow disabling the next line, e.g: +# self.namespaces +# // typos:disable-next-line - false positive "typ" +# .retain(|ns| ns.typ() != LinuxNamespaceType::Network); +extend-ignore-re = [ + "(?Rm)^\\s*//\\s*typos:disable-next-line(\\s.*|\\s*)(\\r|\\n)+[^\\r\\n]*$", +] + [files] extend-exclude = ["docs/mermaid.min.js"] \ No newline at end of file