diff --git a/hydro_lang/Cargo.toml b/hydro_lang/Cargo.toml index 0ebfe55e1e67..66a4e6b0cceb 100644 --- a/hydro_lang/Cargo.toml +++ b/hydro_lang/Cargo.toml @@ -22,6 +22,7 @@ build = [ "dep:dfir_lang" ] [dependencies] bincode = "1.3.1" +ctor = "0.2.8" hydro_deploy = { path = "../hydro_deploy/core", version = "^0.11.0", optional = true } dfir_rs = { path = "../dfir_rs", version = "^0.11.0", default-features = false, features = ["deploy_integration"] } dfir_lang = { path = "../dfir_lang", version = "^0.11.0", optional = true } @@ -46,7 +47,6 @@ stageleft_tool = { path = "../stageleft_tool", version = "^0.5.0" } [dev-dependencies] async-ssh2-lite = { version = "0.5.0", features = ["vendored-openssl"] } -ctor = "0.2.8" hydro_deploy = { path = "../hydro_deploy/core", version = "^0.11.0" } insta = "1.39" tokio-test = "0.4.4" diff --git a/hydro_lang/src/lib.rs b/hydro_lang/src/lib.rs index 1f434e6bd27f..438ffcc73bf2 100644 --- a/hydro_lang/src/lib.rs +++ b/hydro_lang/src/lib.rs @@ -46,6 +46,12 @@ mod staging_util; #[cfg(feature = "deploy")] pub mod test_util; +#[ctor::ctor] +fn add_private_reexports() { + stageleft::add_private_reexport(vec!["tokio", "time", "instant"], vec!["tokio", "time"]); + stageleft::add_private_reexport(vec!["bytes", "bytes"], vec!["bytes"]); +} + #[stageleft::runtime] #[cfg(test)] mod tests { diff --git a/stageleft/src/lib.rs b/stageleft/src/lib.rs index 8ee222bce757..b8c662be384e 100644 --- a/stageleft/src/lib.rs +++ b/stageleft/src/lib.rs @@ -22,7 +22,7 @@ use runtime_support::FreeVariableWithContext; use crate::runtime_support::get_final_crate_name; mod type_name; -pub use type_name::quote_type; +pub use type_name::{add_private_reexport, quote_type}; #[cfg(windows)] #[macro_export] diff --git a/stageleft/src/type_name.rs b/stageleft/src/type_name.rs index 486fb866f0a5..55546afdb481 100644 --- a/stageleft/src/type_name.rs +++ b/stageleft/src/type_name.rs @@ -1,152 +1,82 @@ +use std::sync::{LazyLock, RwLock}; + use proc_macro2::Span; use syn::visit_mut::VisitMut; use syn::{parse_quote, TypeInfer}; use crate::runtime_support::get_final_crate_name; -/// Rewrites use of alloc::string::* to use std::string::* -struct RewriteAlloc { +type ReexportsSet = LazyLock, Vec<&'static str>)>>>; +static PRIVATE_REEXPORTS: ReexportsSet = LazyLock::new(|| { + RwLock::new(vec![ + (vec!["alloc"], vec!["std"]), + (vec!["core", "ops", "range"], vec!["std", "ops"]), + (vec!["core", "slice", "iter"], vec!["std", "slice"]), + (vec!["core", "iter", "adapters", "*"], vec!["std", "iter"]), + ( + vec!["std", "collections", "hash", "map"], + vec!["std", "collections", "hash_map"], + ), + (vec!["std", "vec", "into_iter"], vec!["std", "vec"]), + ]) +}); + +/// Adds a private module re-export transformation to the type quoting system. +/// +/// Sometimes, the [`quote_type`] function may produce an uncompilable reference to a +/// type inside a private module if the type is re-exported from a public module +/// (because Rust's `type_name` only gives the path to the original definition). +/// +/// This function adds a rewrite rule for such cases, where the `from` path is +/// replaced with the `to` path. The paths are given as a list of strings, where +/// each string is a segment of the path. The `from` path is matched against the +/// beginning of the type path, and if it matches, the `to` path is substituted +/// in its place. The `from` path may contain a wildcard `*` to glob a segment. +/// +/// # Example +/// ```rust +/// stageleft::add_private_reexport( +/// vec!["std", "collections", "hash", "map"], +/// vec!["std", "collections", "hash_map"], +/// ); +/// ``` +pub fn add_private_reexport(from: Vec<&'static str>, to: Vec<&'static str>) { + let mut transformations = PRIVATE_REEXPORTS.write().unwrap(); + transformations.push((from, to)); +} + +struct RewritePrivateReexports { mapping: Option<(String, String)>, } -impl VisitMut for RewriteAlloc { +impl VisitMut for RewritePrivateReexports { fn visit_path_mut(&mut self, i: &mut syn::Path) { - if i.segments.iter().take(1).collect::>() - == vec![&syn::PathSegment::from(syn::Ident::new( - "alloc", - Span::call_site(), - ))] - { - *i.segments.first_mut().unwrap() = - syn::PathSegment::from(syn::Ident::new("std", Span::call_site())); - } else if i.segments.iter().take(3).collect::>() - == vec![ - &syn::PathSegment::from(syn::Ident::new("core", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("ops", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("range", Span::call_site())), - ] - { - *i = syn::Path { - leading_colon: i.leading_colon, - segments: syn::punctuated::Punctuated::from_iter( - vec![ - syn::PathSegment::from(syn::Ident::new("std", Span::call_site())), - syn::PathSegment::from(syn::Ident::new("ops", Span::call_site())), - ] - .into_iter() - .chain(i.segments.iter().skip(3).cloned()), - ), - }; - } else if i.segments.iter().take(3).collect::>() - == vec![ - &syn::PathSegment::from(syn::Ident::new("core", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("slice", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("iter", Span::call_site())), - ] - { - *i = syn::Path { - leading_colon: i.leading_colon, - segments: syn::punctuated::Punctuated::from_iter( - vec![ - syn::PathSegment::from(syn::Ident::new("std", Span::call_site())), - syn::PathSegment::from(syn::Ident::new("slice", Span::call_site())), - ] - .into_iter() - .chain(i.segments.iter().skip(3).cloned()), - ), - }; - } else if i.segments.iter().take(3).collect::>() - == vec![ - &syn::PathSegment::from(syn::Ident::new("core", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("iter", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("adapters", Span::call_site())), - ] - { - *i = syn::Path { - leading_colon: i.leading_colon, - segments: syn::punctuated::Punctuated::from_iter( - vec![ - syn::PathSegment::from(syn::Ident::new("std", Span::call_site())), - syn::PathSegment::from(syn::Ident::new("iter", Span::call_site())), - ] - .into_iter() - .chain(i.segments.iter().skip(4).cloned()), - ), - }; - } else if i.segments.iter().take(4).collect::>() - == vec![ - &syn::PathSegment::from(syn::Ident::new("std", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("collections", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("hash", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("map", Span::call_site())), - ] - { - *i = syn::Path { - leading_colon: i.leading_colon, - segments: syn::punctuated::Punctuated::from_iter( - vec![ - syn::PathSegment::from(syn::Ident::new("std", Span::call_site())), - syn::PathSegment::from(syn::Ident::new("collections", Span::call_site())), - syn::PathSegment::from(syn::Ident::new("hash_map", Span::call_site())), - ] - .into_iter() - .chain(i.segments.iter().skip(4).cloned()), - ), - }; - } else if i.segments.iter().take(3).collect::>() - == vec![ - &syn::PathSegment::from(syn::Ident::new("std", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("vec", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("into_iter", Span::call_site())), - ] - { - *i = syn::Path { - leading_colon: i.leading_colon, - segments: syn::punctuated::Punctuated::from_iter( - vec![ - syn::PathSegment::from(syn::Ident::new("std", Span::call_site())), - syn::PathSegment::from(syn::Ident::new("vec", Span::call_site())), - ] - .into_iter() - .chain(i.segments.iter().skip(3).cloned()), - ), - }; - } else if i.segments.iter().take(3).collect::>() - == vec![ - &syn::PathSegment::from(syn::Ident::new("tokio", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("time", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("instant", Span::call_site())), - ] - { - *i = syn::Path { - leading_colon: i.leading_colon, - segments: syn::punctuated::Punctuated::from_iter( - vec![ - syn::PathSegment::from(syn::Ident::new("tokio", Span::call_site())), - syn::PathSegment::from(syn::Ident::new("time", Span::call_site())), - ] - .into_iter() - .chain(i.segments.iter().skip(3).cloned()), - ), - }; - } else if i.segments.iter().take(2).collect::>() - == vec![ - &syn::PathSegment::from(syn::Ident::new("bytes", Span::call_site())), - &syn::PathSegment::from(syn::Ident::new("bytes", Span::call_site())), - ] - { - *i = syn::Path { - leading_colon: i.leading_colon, - segments: syn::punctuated::Punctuated::from_iter( - vec![syn::PathSegment::from(syn::Ident::new( - "bytes", - Span::call_site(), - ))] - .into_iter() - .chain(i.segments.iter().skip(2).cloned()), - ), - }; - } else if let Some((macro_name, final_name)) = &self.mapping { + let transformations = PRIVATE_REEXPORTS.read().unwrap(); + for (from, to) in transformations.iter() { + #[expect(clippy::cmp_owned, reason = "buggy lint for syn::Ident::to_string")] + if i.segments.len() >= from.len() + && from + .iter() + .zip(i.segments.iter()) + .all(|(f, s)| *f == "*" || *f == s.ident.to_string()) + { + *i = syn::Path { + leading_colon: i.leading_colon, + segments: syn::punctuated::Punctuated::from_iter( + to.iter() + .map(|s| syn::PathSegment::from(syn::Ident::new(s, Span::call_site()))) + .chain(i.segments.iter().skip(from.len()).cloned()), + ), + }; + + drop(transformations); + self.visit_path_mut(i); + return; + } + } + drop(transformations); + + if let Some((macro_name, final_name)) = &self.mapping { if i.segments.first().unwrap().ident == macro_name { *i.segments.first_mut().unwrap() = syn::parse2(get_final_crate_name(final_name)).unwrap(); @@ -154,14 +84,10 @@ impl VisitMut for RewriteAlloc { i.segments.insert(1, parse_quote!(__staged)); } else { syn::visit_mut::visit_path_mut(self, i); - return; } } else { syn::visit_mut::visit_path_mut(self, i); - return; } - - self.visit_path_mut(i); } } @@ -202,7 +128,7 @@ pub fn quote_type() -> syn::Type { }); let mapping = super::runtime_support::MACRO_TO_CRATE.with(|m| m.borrow().clone()); ElimClosureToInfer.visit_type_mut(&mut t_type); - RewriteAlloc { mapping }.visit_type_mut(&mut t_type); + RewritePrivateReexports { mapping }.visit_type_mut(&mut t_type); t_type } diff --git a/stageleft_tool/src/lib.rs b/stageleft_tool/src/lib.rs index 2af956a2e935..a925245d7eb3 100644 --- a/stageleft_tool/src/lib.rs +++ b/stageleft_tool/src/lib.rs @@ -212,7 +212,14 @@ impl VisitMut for GenFinalPubVistor { } } - i.vis = parse_quote!(pub); + let is_ctor = i + .attrs + .iter() + .any(|a| a.path().to_token_stream().to_string() == "ctor :: ctor"); + + if !is_ctor { + i.vis = parse_quote!(pub); + } syn::visit_mut::visit_item_fn_mut(self, i); }