Skip to content

Commit

Permalink
Arc<RwLock<Module>>
Browse files Browse the repository at this point in the history
Signed-off-by: he1pa <18012015693@163.com>
  • Loading branch information
He1pa committed Oct 29, 2024
1 parent ed7d00e commit c9ee296
Show file tree
Hide file tree
Showing 36 changed files with 541 additions and 328 deletions.
64 changes: 46 additions & 18 deletions kclvm/ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
use kclvm_utils::path::PathPrefix;
use serde::{ser::SerializeStruct, Deserialize, Serialize, Serializer};
use std::collections::HashMap;
use std::{
collections::HashMap,
sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard},
};

use compiler_base_span::{Loc, Span};
use std::fmt::Debug;
Expand Down Expand Up @@ -383,63 +386,88 @@ pub struct SerializeProgram {
impl Into<SerializeProgram> for Program {
fn into(self) -> SerializeProgram {
SerializeProgram {
root: self.root,
root: self.root.clone(),
pkgs: self
.pkgs
.iter()
.map(|(name, modules)| {
(
name.clone(),
modules.iter().map(|m| m.as_ref().clone()).collect(),
modules
.iter()
.map(|m| self.get_module(m).unwrap().unwrap().clone())
.collect(),
)
})
.collect(),
}
}
}

#[derive(Debug, Clone, Default, PartialEq)]
#[derive(Debug, Clone, Default)]
pub struct Program {
pub root: String,
pub pkgs: HashMap<String, Vec<Arc<Module>>>,
pub pkgs: HashMap<String, Vec<String>>,
pub modules: HashMap<String, Arc<RwLock<Module>>>,
}

impl Program {
/// Get main entry files.
pub fn get_main_files(&self) -> Vec<String> {
match self.pkgs.get(crate::MAIN_PKG) {
Some(modules) => modules.iter().map(|m| m.filename.clone()).collect(),
Some(modules) => modules.clone(),
None => vec![],
}
}
/// Get the first module in the main package.
pub fn get_main_package_first_module(&self) -> Option<Arc<Module>> {
pub fn get_main_package_first_module(&self) -> Option<RwLockReadGuard<'_, Module>> {
match self.pkgs.get(crate::MAIN_PKG) {
Some(modules) => modules.first().cloned(),
Some(modules) => match modules.first() {
Some(first_module_path) => self.get_module(&first_module_path).unwrap_or(None),
None => None,
},
None => None,
}
}
/// Get stmt on position
pub fn pos_to_stmt(&self, pos: &Position) -> Option<Node<Stmt>> {
for (_, v) in &self.pkgs {
for m in v {
if m.filename == pos.filename {
return m.pos_to_stmt(pos);
if let Ok(m) = self.get_module(m) {
let m = m?;
if m.filename == pos.filename {
return m.pos_to_stmt(pos);
}
}
}
}
None
}

pub fn get_module(&self, module_name: &str) -> Option<&Module> {
for (_, modules) in &self.pkgs {
for module in modules {
if module.filename == module_name {
return Some(module);
}
}
pub fn get_module(
&self,
module_path: &str,
) -> Result<Option<RwLockReadGuard<'_, Module>>, &str> {
match self.modules.get(module_path) {
Some(module_ref) => match module_ref.read() {
Ok(m) => Ok(Some(m)),
Err(_) => Err("Failed to acquire module lock"),
},
None => Ok(None),
}
}

pub fn get_mut_module(
&self,
module_path: &str,
) -> Result<Option<RwLockWriteGuard<'_, Module>>, &str> {
match self.modules.get(module_path) {
Some(module_ref) => match module_ref.write() {
Ok(m) => Ok(Some(m)),
Err(_) => Err("Failed to acquire module lock"),
},
None => Ok(None),
}
None
}
}

Expand Down
4 changes: 2 additions & 2 deletions kclvm/evaluator/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl<'ctx> Evaluator<'ctx> {
self.init_scope(kclvm_ast::MAIN_PKG);
let modules: Vec<Module> = modules
.iter()
.map(|arc| arc.clone().as_ref().clone())
.map(|m| self.program.get_module(m).unwrap().unwrap().clone())
.collect();
self.compile_ast_modules(&modules);
}
Expand All @@ -171,7 +171,7 @@ impl<'ctx> Evaluator<'ctx> {
self.init_scope(kclvm_ast::MAIN_PKG);
let modules: Vec<Module> = modules
.iter()
.map(|arc| arc.clone().as_ref().clone())
.map(|m| self.program.get_module(m).unwrap().unwrap().clone())
.collect();
self.compile_ast_modules(&modules)
} else {
Expand Down
3 changes: 2 additions & 1 deletion kclvm/evaluator/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ impl<'ctx> TypedResultWalker<'ctx> for Evaluator<'ctx> {
self.init_scope(&pkgpath);
let modules: Vec<Module> = modules
.iter()
.map(|arc| arc.clone().as_ref().clone())
.map(|m| self.program.get_module(&m).unwrap().unwrap().clone())
.collect();
self.compile_ast_modules(&modules);
self.pop_pkgpath();
Expand Down Expand Up @@ -1158,6 +1158,7 @@ impl<'ctx> Evaluator<'ctx> {
.get(&pkgpath_without_prefix!(frame.pkgpath))
{
if let Some(module) = module_list.get(*index) {
let module = self.program.get_module(module).unwrap().unwrap();
if let Some(stmt) = module.body.get(setter.stmt) {
self.push_backtrack_meta(setter);
self.walk_stmt(stmt).expect(INTERNAL_ERROR_MSG);
Expand Down
4 changes: 2 additions & 2 deletions kclvm/evaluator/src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl<'ctx> Evaluator<'ctx> {
let modules = self.program.pkgs.get(pkgpath).expect(&msg);
let modules: Vec<Module> = modules
.iter()
.map(|arc| arc.clone().as_ref().clone())
.map(|m| self.program.get_module(m).unwrap().unwrap().clone())
.collect();
modules
} else if pkgpath.starts_with(kclvm_runtime::PKG_PATH_PREFIX)
Expand All @@ -56,7 +56,7 @@ impl<'ctx> Evaluator<'ctx> {
.expect(kcl_error::INTERNAL_ERROR_MSG);
let modules: Vec<Module> = modules
.iter()
.map(|arc| arc.clone().as_ref().clone())
.map(|m| self.program.get_module(m).unwrap().unwrap().clone())
.collect();
modules
} else {
Expand Down
3 changes: 2 additions & 1 deletion kclvm/loader/src/option.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ pub fn list_options(opts: &LoadPackageOptions) -> Result<Vec<OptionHelp>> {
for (pkgpath, modules) in &packages.program.pkgs {
extractor.pkgpath = pkgpath.clone();
for module in modules {
extractor.walk_module(module)
let module = packages.program.get_module(module).unwrap().unwrap();
extractor.walk_module(&module)
}
}
Ok(extractor.options)
Expand Down
53 changes: 32 additions & 21 deletions kclvm/parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ pub fn parse_single_file(filename: &str, code: Option<String>) -> Result<ParseFi
);
let result = loader.load_main()?;
let module = match result.program.get_main_package_first_module() {
Some(module) => module.as_ref().clone(),
Some(module) => module.clone(),
None => ast::Module::default(),
};
let file_graph = match loader.file_graph.read() {
Expand Down Expand Up @@ -317,7 +317,7 @@ pub type KCLModuleCache = Arc<RwLock<ModuleCache>>;

#[derive(Default, Debug)]
pub struct ModuleCache {
pub ast_cache: IndexMap<PathBuf, Arc<ast::Module>>,
pub ast_cache: IndexMap<PathBuf, Arc<RwLock<ast::Module>>>,
pub file_pkg: IndexMap<PathBuf, HashSet<PkgFile>>,
pub dep_cache: IndexMap<PkgFile, (Vec<PkgFile>, PkgMap)>,
}
Expand Down Expand Up @@ -666,24 +666,21 @@ pub fn parse_file(
file: PkgFile,
src: Option<String>,
module_cache: KCLModuleCache,
pkgs: &mut HashMap<String, Vec<Arc<Module>>>,
pkgs: &mut HashMap<String, Vec<String>>,
pkgmap: &mut PkgMap,
file_graph: FileGraphCache,
opts: &LoadProgramOptions,
) -> Result<Vec<PkgFile>> {
let m = Arc::new(parse_file_with_session(
sess.clone(),
file.path.to_str().unwrap(),
src,
)?);
let m = parse_file_with_session(sess.clone(), file.path.to_str().unwrap(), src)?;

let (deps, new_pkgmap) = get_deps(&file, m.as_ref(), pkgs, pkgmap, opts, sess)?;
let (deps, new_pkgmap) = get_deps(&file, &m, pkgs, pkgmap, opts, sess)?;
let m_ref = Arc::new(RwLock::new(m));
pkgmap.extend(new_pkgmap.clone());
match &mut module_cache.write() {
Ok(module_cache) => {
module_cache
.ast_cache
.insert(file.canonicalize(), m.clone());
.insert(file.canonicalize(), m_ref.clone());
match module_cache.file_pkg.get_mut(&file.canonicalize()) {
Some(s) => {
s.insert(file.clone());
Expand Down Expand Up @@ -713,7 +710,7 @@ pub fn parse_file(
pub fn get_deps(
file: &PkgFile,
m: &Module,
modules: &mut HashMap<String, Vec<Arc<Module>>>,
modules: &mut HashMap<String, Vec<String>>,
pkgmap: &mut PkgMap,
opts: &LoadProgramOptions,
sess: ParseSessionRef,
Expand Down Expand Up @@ -773,7 +770,7 @@ pub fn parse_pkg(
sess: ParseSessionRef,
files: Vec<(PkgFile, Option<String>)>,
module_cache: KCLModuleCache,
pkgs: &mut HashMap<String, Vec<Arc<Module>>>,
pkgs: &mut HashMap<String, Vec<String>>,
pkgmap: &mut PkgMap,
file_graph: FileGraphCache,
opts: &LoadProgramOptions,
Expand All @@ -799,7 +796,7 @@ pub fn parse_entry(
sess: ParseSessionRef,
entry: &entry::Entry,
module_cache: KCLModuleCache,
pkgs: &mut HashMap<String, Vec<Arc<Module>>>,
pkgs: &mut HashMap<String, Vec<String>>,
pkgmap: &mut PkgMap,
file_graph: FileGraphCache,
opts: &LoadProgramOptions,
Expand Down Expand Up @@ -860,8 +857,15 @@ pub fn parse_entry(
Some(m) => {
let (deps, new_pkgmap) =
m_cache.dep_cache.get(&file).cloned().unwrap_or_else(|| {
get_deps(&file, m.as_ref(), pkgs, pkgmap, opts, sess.clone())
.unwrap()
get_deps(
&file,
&m.read().unwrap(),
pkgs,
pkgmap,
opts,
sess.clone(),
)
.unwrap()
});
pkgmap.extend(new_pkgmap.clone());

Expand Down Expand Up @@ -916,7 +920,7 @@ pub fn parse_program(
) -> Result<LoadProgramResult> {
let compile_entries = get_compile_entries_from_paths(&paths, &opts)?;
let workdir = compile_entries.get_root_path().to_string();
let mut pkgs: HashMap<String, Vec<Arc<Module>>> = HashMap::new();
let mut pkgs: HashMap<String, Vec<String>> = HashMap::new();
let mut pkgmap = PkgMap::new();
let mut new_files = HashSet::new();

Expand Down Expand Up @@ -964,7 +968,10 @@ pub fn parse_program(
}
Err(e) => return Err(anyhow::anyhow!("Parse program failed: {e}")),
};

let mut modules: HashMap<String, Arc<RwLock<Module>>> = HashMap::new();
for file in files.iter() {
let filename = file.canonicalize().to_str().unwrap().to_string();
let mut m_ref = match module_cache.read() {
Ok(module_cache) => module_cache
.ast_cache
Expand All @@ -978,23 +985,27 @@ pub fn parse_program(
};
if new_files.contains(file) {
let pkg = pkgmap.get(file).expect("file not in pkgmap");
let mut m = Arc::make_mut(&mut m_ref);
let mut m = m_ref.write().unwrap();
fix_rel_import_path_with_file(&pkg.pkg_root, &mut m, file, &pkgmap, opts, sess.clone());
}

modules.insert(filename, m_ref);
match pkgs.get_mut(&file.pkg_path) {
Some(modules) => {
modules.push(m_ref);
Some(pkg_modules) => {
pkg_modules.push(file.path.to_str().unwrap().to_string());
}
None => {
pkgs.insert(file.pkg_path.clone(), vec![m_ref]);
pkgs.insert(
file.pkg_path.clone(),
vec![file.path.to_str().unwrap().to_string()],
);
}
}
}

let program = ast::Program {
root: workdir,
pkgs,
modules,
};

Ok(LoadProgramResult {
Expand Down
18 changes: 12 additions & 6 deletions kclvm/parser/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,16 +304,18 @@ pub fn test_import_vendor() {
.unwrap()
.program;
assert_eq!(m.pkgs.len(), pkgs.len());
m.pkgs.into_iter().for_each(|(name, modules)| {
m.pkgs.clone().into_iter().for_each(|(name, modules)| {
println!("{:?} - {:?}", test_case_name, name);
assert!(pkgs.contains(&name.as_str()));
for pkg in pkgs.clone() {
if name == pkg {
if name == "__main__" {
assert_eq!(modules.len(), 1);
assert_eq!(modules.get(0).unwrap().filename, test_case_path);
let module = m.get_module(modules.get(0).unwrap()).unwrap().unwrap();
assert_eq!(module.filename, test_case_path);
} else {
modules.into_iter().for_each(|module| {
let module = m.get_module(&module).unwrap().unwrap();
assert!(module.filename.contains(&vendor));
});
}
Expand Down Expand Up @@ -354,15 +356,17 @@ pub fn test_import_vendor_without_kclmod() {
.unwrap()
.program;
assert_eq!(m.pkgs.len(), pkgs.len());
m.pkgs.into_iter().for_each(|(name, modules)| {
m.pkgs.clone().into_iter().for_each(|(name, modules)| {
assert!(pkgs.contains(&name.as_str()));
for pkg in pkgs.clone() {
if name == pkg {
if name == "__main__" {
assert_eq!(modules.len(), 1);
assert_eq!(modules.get(0).unwrap().filename, test_case_path);
let module = m.get_module(modules.get(0).unwrap()).unwrap().unwrap();
assert_eq!(module.filename, test_case_path);
} else {
modules.into_iter().for_each(|module| {
let module = m.get_module(&module).unwrap().unwrap();
assert!(module.filename.contains(&vendor));
});
}
Expand Down Expand Up @@ -594,15 +598,17 @@ fn test_import_vendor_by_external_arguments() {
.unwrap()
.program;
assert_eq!(m.pkgs.len(), pkgs.len());
m.pkgs.into_iter().for_each(|(name, modules)| {
m.pkgs.clone().into_iter().for_each(|(name, modules)| {
assert!(pkgs.contains(&name.as_str()));
for pkg in pkgs.clone() {
if name == pkg {
if name == "__main__" {
assert_eq!(modules.len(), 1);
assert_eq!(modules.get(0).unwrap().filename, test_case_path);
let module = m.get_module(modules.get(0).unwrap()).unwrap().unwrap();
assert_eq!(module.filename, test_case_path);
} else {
modules.into_iter().for_each(|module| {
let module = m.get_module(&module).unwrap().unwrap();
assert!(module.filename.contains(&vendor));
});
}
Expand Down
Loading

0 comments on commit c9ee296

Please sign in to comment.