Skip to content

Commit

Permalink
✨ recursive_guard
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Sep 30, 2024
1 parent 5e3299c commit 589fb58
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 14 deletions.
24 changes: 15 additions & 9 deletions arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@ def dispatch(*events: type[Event], predicate: Callable[[Event], bool] | None = N
return plugin.dispatch(*events, predicate=predicate)


_recrusive_guard = set()


def load_plugin(path: str) -> Plugin | None:
def load_plugin(path: str, recursive_guard: set[str] | None = None) -> Plugin | None:
"""
以导入路径方式加载模块
Args:
path (str): 模块路径
recursive_guard (set[str]): 递归保护
"""
if recursive_guard is None:
recursive_guard = set()
if path in service._submoded:
logger.error(f"plugin {path!r} is already defined as submodule of {service._submoded[path]!r}")
return
if path in service.plugins:
return service.plugins[path]
try:
Expand All @@ -45,15 +48,18 @@ def load_plugin(path: str) -> Plugin | None:
logger.success(f"loaded plugin {path!r}")
if mod.__name__ in service._unloaded:
if mod.__name__ in service._referents and service._referents[mod.__name__]:
for referent in service._referents[mod.__name__]:
if referent in _recrusive_guard:
referents = service._referents[mod.__name__].copy()
service._referents[mod.__name__].clear()
for referent in referents:
if referent in recursive_guard:
continue
_recrusive_guard.add(referent)
if referent in service.plugins:
logger.debug(f"reloading {mod.__name__}'s referent {referent!r}")
dispose(referent)
load_plugin(referent)
_recrusive_guard.clear()
if not load_plugin(referent):
service._referents[mod.__name__].add(referent)
else:
recursive_guard.add(referent)
service._unloaded.discard(mod.__name__)
return mod.__plugin__
except RegisterNotInPluginError as e:
Expand Down
3 changes: 3 additions & 0 deletions arclet/entari/plugin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,9 @@ def __getattr__(self, name: str):
if plug := inspect.currentframe().f_back.f_globals.get("__plugin__"): # type: ignore
if plug.id != self.__plugin_id:
service._referents[self.__plugin_id].add(plug.id)
elif plug := inspect.currentframe().f_back.f_back.f_globals.get("__plugin__"): # type: ignore
if plug.id != self.__plugin_id:
service._referents[self.__plugin_id].add(plug.id)
return getattr(self.__get_module(), name)

def __setattr__(self, name: str, value):
Expand Down
8 changes: 4 additions & 4 deletions arclet/entari/plugin/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def exec_module(self, module: ModuleType) -> None:
raise
else:
plugin.submodules[module.__name__] = module
service._submoded[module.__name__] = plugin.id
return

if self.loaded:
Expand Down Expand Up @@ -196,10 +197,9 @@ def find_spec(
if module_spec.name in service.plugins:
module_spec.loader = PluginLoader(fullname, module_origin)
return module_spec
for plug in service.plugins.values():
if module_spec.name in plug.submodules:
module_spec.loader = PluginLoader(fullname, module_origin, plug.id)
return module_spec
if module_spec.name in service._submoded:
module_spec.loader = PluginLoader(fullname, module_origin, service._submoded[module_spec.name])
return module_spec
return


Expand Down
2 changes: 2 additions & 0 deletions arclet/entari/plugin/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ class PluginService(Service):
_keep_values: dict[str, dict[str, "KeepingVariable"]]
_referents: dict[str, set[str]]
_unloaded: set[str]
_submoded: dict[str, str]

def __init__(self):
super().__init__()
self.plugins = {}
self._keep_values = {}
self._referents = {}
self._unloaded = set()
self._submoded = {}

@property
def required(self) -> set[str]:
Expand Down
2 changes: 1 addition & 1 deletion example_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def append(data: str, session: Session):
async def show(session: Session):
await session.send_message(f"Data: {kept_data}")

TEST = 5
TEST = 6

print([*Plugin.current().dispatchers.keys()])
print(Plugin.current().submodules)
Expand Down

0 comments on commit 589fb58

Please sign in to comment.