diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 64173266..136eb4d5 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -178,6 +178,13 @@ def signal(fn: CallableSyncOrAsyncReturnNoneType) -> CallableSyncOrAsyncReturnNo ... +@overload +def signal() -> ( + Callable[[CallableSyncOrAsyncReturnNoneType], CallableSyncOrAsyncReturnNoneType] +): + ... + + @overload def signal( *, name: str @@ -232,12 +239,10 @@ def with_name( ) return fn - if name is not None or dynamic: + if not fn: if name is not None and dynamic: raise RuntimeError("Cannot provide name and dynamic boolean") return partial(with_name, name) - if fn is None: - raise RuntimeError("Cannot create signal without function or name or dynamic") return with_name(fn.__name__, fn) @@ -246,6 +251,11 @@ def query(fn: CallableType) -> CallableType: ... +@overload +def query() -> Callable[[CallableType], CallableType]: + ... + + @overload def query(*, name: str) -> Callable[[CallableType], CallableType]: ... @@ -302,12 +312,10 @@ def with_name( ) return fn - if name is not None or dynamic: + if not fn: if name is not None and dynamic: raise RuntimeError("Cannot provide name and dynamic boolean") return partial(with_name, name) - if fn is None: - raise RuntimeError("Cannot create query without function or name or dynamic") if inspect.iscoroutinefunction(fn): warnings.warn( "Queries as async def functions are deprecated", @@ -921,6 +929,16 @@ def update( ... +@overload +def update() -> ( + Callable[ + [Callable[MultiParamSpec, ReturnType]], + UpdateMethodMultiParam[MultiParamSpec, ReturnType], + ] +): + ... + + @overload def update( *, name: str @@ -987,12 +1005,10 @@ def with_name( setattr(fn, "validator", partial(_update_validator, defn)) return fn - if name is not None or dynamic: + if not fn: if name is not None and dynamic: raise RuntimeError("Cannot provide name and dynamic boolean") return partial(with_name, name) - if fn is None: - raise RuntimeError("Cannot create update without function or name or dynamic") return with_name(fn.__name__, fn) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 510fa18d..b34593e2 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -42,6 +42,10 @@ def signal2(self): def signal3(self, name: str, args: Sequence[RawValue]): pass + @workflow.signal() + def signal4(self): + pass + @workflow.query def query1(self): pass @@ -54,6 +58,10 @@ def query2(self): def query3(self, name: str, args: Sequence[RawValue]): pass + @workflow.query() + def query4(self): + pass + @workflow.update def update1(self): pass @@ -66,6 +74,10 @@ def update2(self): def update3(self, name: str, args: Sequence[RawValue]): pass + @workflow.update() + def update4(self): + pass + def test_workflow_defn_good(): # Although the API is internal, we want to check the literal definition just