Skip to content

Commit

Permalink
✨ ProxyModule
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Sep 30, 2024
1 parent 44f159c commit fddc228
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 58 deletions.
3 changes: 2 additions & 1 deletion arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ def load_plugins(dir_: str | PathLike | Path):

def dispose(plugin: str):
if plugin not in service.plugins:
return
return False
_plugin = service.plugins[plugin]
_plugin.dispose()
return True


@init_spec(PluginMetadata)
Expand Down
21 changes: 21 additions & 0 deletions arclet/entari/plugin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ def validate(self, func):
f"`package({func.__module__!r})` before import it."
)

@property
def proxy(self):
return _ProxyModule(self.id)


class KeepingVariable:
def __init__(self, obj: T, dispose: Callable[[T], None] | None = None):
Expand All @@ -213,3 +217,20 @@ def keeping(id_: str, obj: T, dispose: Callable[[T], None] | None = None) -> T:
else:
obj = service._keep_values[plug.id][id_].obj # type: ignore
return obj


class _ProxyModule:
def __init__(self, plugin_id: str) -> None:
self.__plugin_id = plugin_id

def __getattr__(self, name: str):
if self.__plugin_id not in service.plugins:
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
return getattr(service.plugins[self.__plugin_id].module, name)

def __setattr__(self, name: str, value):
if name == "_ProxyModule__plugin_id":
return super().__setattr__(name, value)
if self.__plugin_id not in service.plugins:
raise NameError(f"Plugin {self.__plugin_id!r} is not loaded")
setattr(service.plugins[self.__plugin_id].module, name, value)
114 changes: 61 additions & 53 deletions arclet/entari/plugin/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ def _unpack_import_from(__fullname: str, mod: str, aliases: list[str]):
return tuple(getattr(_mod, alias) for alias in aliases)


def _check_import(name: str, plugin_name: str):
if name in service.plugins:
return service.plugins[name].proxy
if name in _SUBMODULE_WAITLIST.get(plugin_name, ()):
return import_plugin(name)
return __import__(name)


class PluginLoader(SourceFileLoader):
def __init__(self, fullname: str, path: str, parent_plugin_id: Optional[str] = None) -> None:
self.loaded = False
Expand All @@ -50,17 +58,19 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
",".join(aliases)
+ f"=__unpack_import_from('{self.name}', '', {[alias.name for alias in body.names]!r})"
).body[0]
nodes.body[i].lineno = body.lineno
nodes.body[i].end_lineno = body.end_lineno
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
node.end_lineno = body.end_lineno # type: ignore
if body.level == 1:
if body.module is None:
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
",".join(aliases)
+ f"=__unpack_import_from('{self.name}', '.', {[alias.name for alias in body.names]!r})"
).body[0]
nodes.body[i].lineno = body.lineno
nodes.body[i].end_lineno = body.end_lineno
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
node.end_lineno = body.end_lineno # type: ignore
else:
aliases = [alias.asname or alias.name for alias in body.names]
nodes.body[i] = ast.parse(
Expand All @@ -70,8 +80,9 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
f"{[alias.name for alias in body.names]!r})"
)
).body[0]
nodes.body[i].lineno = body.lineno
nodes.body[i].end_lineno = body.end_lineno
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
node.end_lineno = body.end_lineno # type: ignore
elif (
isinstance(body, ast.Expr)
and isinstance(body.value, ast.Call)
Expand All @@ -85,17 +96,11 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore
nodes.body[i] = ast.parse(
",".join(aliases)
+ "="
+ ",".join(
(
f"__import_plugin({alias.name!r})"
if (alias.name in _SUBMODULE_WAITLIST.get(self.name, ()) or alias.name in service.plugins)
else f"__import__({alias.name!r})"
)
for alias in body.names
)
+ ",".join((f"__check_import({alias.name!r}, {self.name!r})") for alias in body.names)
).body[0]
nodes.body[i].lineno = body.lineno
nodes.body[i].end_lineno = body.end_lineno
for node in ast.walk(nodes.body[i]):
node.lineno = body.lineno # type: ignore
node.end_lineno = body.end_lineno # type: ignore
return _bootstrap._call_with_frames_removed(compile, nodes, path, "exec", dont_inherit=True, optimize=_optimize)

def create_module(self, spec) -> Optional[ModuleType]:
Expand All @@ -105,12 +110,12 @@ def create_module(self, spec) -> Optional[ModuleType]:
return super().create_module(spec)

def exec_module(self, module: ModuleType) -> None:
if plugin := _current_plugin.get(service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None):
if plugin := service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None:
if module.__name__ == plugin.module.__name__: # from . import xxxx
return
setattr(module, "__plugin__", plugin)
setattr(module, "__unpack_import_from", _unpack_import_from)
setattr(module, "__import_plugin", import_plugin)
setattr(module, "__check_import", _check_import)
try:
super().exec_module(module)
except Exception:
Expand All @@ -127,7 +132,7 @@ def exec_module(self, module: ModuleType) -> None:
plugin = Plugin(module.__name__, module)
setattr(module, "__plugin__", plugin)
setattr(module, "__unpack_import_from", _unpack_import_from)
setattr(module, "__import_plugin", import_plugin)
setattr(module, "__check_import", _check_import)

# enter plugin context
_plugin_token = _current_plugin.set(plugin)
Expand All @@ -148,6 +153,41 @@ def exec_module(self, module: ModuleType) -> None:
return


class _PluginFinder(MetaPathFinder):
@classmethod
def find_spec(
cls,
fullname: str,
path: Optional[Sequence[str]],
target: Optional[ModuleType] = None,
):
module_spec = PathFinder.find_spec(fullname, path, target)
if not module_spec:
return
module_origin = module_spec.origin
if not module_origin:
return
if plug := _current_plugin.get(None):
if plug.module.__spec__ and plug.module.__spec__.origin == module_spec.origin:
return plug.module.__spec__
if module_spec.parent and module_spec.parent == plug.module.__name__:
module_spec.loader = PluginLoader(fullname, module_origin, plug.id)
return module_spec
elif module_spec.name in _SUBMODULE_WAITLIST.get(plug.module.__name__, ()):
module_spec.loader = PluginLoader(fullname, module_origin, plug.id)
# _SUBMODULE_WAITLIST[plug.module.__name__].remove(module_spec.name)
return module_spec

if module_spec.name in service.plugins:
module_spec.loader = PluginLoader(fullname, module_origin)
return module_spec
for plug in service.plugins.values():
if module_spec.name in plug.submodules:
module_spec.loader = PluginLoader(fullname, module_origin, plug.id)
return module_spec
return


def find_spec(name, package=None):
fullname = resolve_name(name, package) if name.startswith(".") else name
parent_name = fullname.rpartition(".")[0]
Expand All @@ -165,6 +205,8 @@ def find_spec(name, package=None):
) from e
else:
parent_path = None
if spec := _PluginFinder.find_spec(fullname, parent_path):
return spec
module_spec = PathFinder.find_spec(fullname, parent_path, None)
if not module_spec:
return
Expand All @@ -185,38 +227,4 @@ def import_plugin(name, package=None):
return


class _PluginFinder(MetaPathFinder):
def find_spec(
self,
fullname: str,
path: Optional[Sequence[str]],
target: Optional[ModuleType] = None,
):
module_spec = PathFinder.find_spec(fullname, path, target)
if not module_spec:
return
module_origin = module_spec.origin
if not module_origin:
return
if plug := _current_plugin.get(None):
if plug.module.__spec__ and plug.module.__spec__.origin == module_spec.origin:
return plug.module.__spec__
if module_spec.parent and module_spec.parent == plug.module.__name__:
module_spec.loader = PluginLoader(fullname, module_origin, plug.id)
return module_spec
elif module_spec.name in _SUBMODULE_WAITLIST[plug.module.__name__]:
module_spec.loader = PluginLoader(fullname, module_origin, plug.id)
# _SUBMODULE_WAITLIST[plug.module.__name__].remove(module_spec.name)
return module_spec

if module_spec.name in service.plugins:
module_spec.loader = PluginLoader(fullname, module_origin)
return module_spec
for plug in service.plugins.values():
if module_spec.name in plug.submodules:
module_spec.loader = PluginLoader(fullname, module_origin, plug.id)
return module_spec
return


sys.meta_path.insert(0, _PluginFinder())
7 changes: 7 additions & 0 deletions example_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import sys

from arclet.alconna import Alconna, AllParam, Args

Expand Down Expand Up @@ -85,3 +86,9 @@ async def append(data: str, session: Session):
@command.on("show")
async def show(session: Session):
await session.send_message(f"Data: {kept_data}")

TEST = 2

print([*Plugin.current().dispatchers.keys()])
print(Plugin.current().submodules)
print("example_plugin not in sys.modules (expect True):", "example_plugin" not in sys.modules)
12 changes: 8 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ async def echoimg(img: Image, session: Session):

@command.on("load {plugin}")
async def load(plugin: str, session: Session):
load_plugin(plugin)
await session.send_message(f"Loaded {plugin}")
if load_plugin(plugin):
await session.send_message(f"Loaded {plugin}")
else:
await session.send_message(f"Failed to load {plugin}")


@command.on("unload {plugin}")
async def unload(plugin: str, session: Session):
dispose_plugin(plugin)
await session.send_message(f"Unloaded {plugin}")
if dispose_plugin(plugin):
await session.send_message(f"Unloaded {plugin}")
else:
await session.send_message(f"Failed to unload {plugin}")

app.run()

0 comments on commit fddc228

Please sign in to comment.