From 8264ed4e8fea385da6a583585d61cf61d4bd9217 Mon Sep 17 00:00:00 2001 From: Michael Jurasovic Date: Thu, 5 Dec 2024 18:45:56 +1100 Subject: [PATCH 1/2] Add py.typed --- src/fastmcp/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/fastmcp/py.typed diff --git a/src/fastmcp/py.typed b/src/fastmcp/py.typed new file mode 100644 index 0000000..e69de29 From 165dd1f6beaaf61b317d5cbc84e19336d90dfa6c Mon Sep 17 00:00:00 2001 From: Michael Jurasovic Date: Thu, 5 Dec 2024 18:46:23 +1100 Subject: [PATCH 2/2] Improve decorator typing --- src/fastmcp/server.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/fastmcp/server.py b/src/fastmcp/server.py index d898853..8b9238d 100644 --- a/src/fastmcp/server.py +++ b/src/fastmcp/server.py @@ -6,7 +6,7 @@ import json import re from itertools import chain -from typing import Any, Callable, Dict, Literal, Sequence +from typing import Any, Callable, Dict, Literal, Sequence, TypeVar, ParamSpec import pydantic_core from pydantic import Field @@ -40,6 +40,7 @@ from fastmcp.exceptions import ResourceError from fastmcp.prompts import Prompt, PromptManager +from fastmcp.prompts.base import PromptResult from fastmcp.resources import FunctionResource, Resource, ResourceManager from fastmcp.tools import ToolManager from fastmcp.utilities.logging import configure_logging, get_logger @@ -47,6 +48,10 @@ logger = get_logger(__name__) +P = ParamSpec("P") +R = TypeVar("R") +R_PromptResult = TypeVar("R_PromptResult", bound=PromptResult) + class Settings(BaseSettings): """FastMCP server settings. @@ -222,7 +227,9 @@ def add_tool( """ self._tool_manager.add_tool(fn, name=name, description=description) - def tool(self, name: str | None = None, description: str | None = None) -> Callable: + def tool( + self, name: str | None = None, description: str | None = None + ) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator to register a tool. Tools can optionally request a Context object by adding a parameter with the Context type annotation. @@ -254,7 +261,7 @@ async def async_tool(x: int, context: Context) -> str: "Did you forget to call it? Use @tool() instead of @tool" ) - def decorator(fn: Callable) -> Callable: + def decorator(fn: Callable[P, R]) -> Callable[P, R]: self.add_tool(fn, name=name, description=description) return fn @@ -275,7 +282,7 @@ def resource( name: str | None = None, description: str | None = None, mime_type: str | None = None, - ) -> Callable: + ) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator to register a function as a resource. The function will be called when the resource is read to generate its content. @@ -309,9 +316,9 @@ def get_weather(city: str) -> str: "Did you forget to call it? Use @resource('uri') instead of @resource" ) - def decorator(fn: Callable) -> Callable: + def decorator(fn: Callable[P, R]) -> Callable[P, R]: @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> Any: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: return fn(*args, **kwargs) # Check if this should be a template @@ -361,7 +368,7 @@ def add_prompt(self, prompt: Prompt) -> None: def prompt( self, name: str | None = None, description: str | None = None - ) -> Callable: + ) -> Callable[[Callable[P, R_PromptResult]], Callable[P, R_PromptResult]]: """Decorator to register a prompt. Args: @@ -402,7 +409,7 @@ async def analyze_file(path: str) -> list[Message]: "Did you forget to call it? Use @prompt() instead of @prompt" ) - def decorator(func: Callable) -> Callable: + def decorator(func: Callable[P, R_PromptResult]) -> Callable[P, R_PromptResult]: prompt = Prompt.from_function(func, name=name, description=description) self.add_prompt(prompt) return func