Skip to content

Commit

Permalink
Use FastStream's native acknowledgements (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
vrslev authored Feb 4, 2025
1 parent 0b66575 commit 82a56b7
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 58 deletions.
5 changes: 5 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ def anyio_backend(request: pytest.FixtureRequest) -> object:
return request.param


@pytest.fixture
def first_server_connection_parameters() -> stompman.ConnectionParameters:
return stompman.ConnectionParameters(host="127.0.0.1", port=9000, login="admin", passcode=":=123")


@pytest.fixture(
params=[
stompman.ConnectionParameters(host="127.0.0.1", port=9000, login="admin", passcode=":=123"),
Expand Down
4 changes: 4 additions & 0 deletions packages/faststream-stomp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,7 @@ if __name__ == "__main__":
```

Also there are `StompRouter` and `TestStompBroker` for testing. It works similarly to built-in brokers from FastStream, I recommend to read the original [FastStream documentation](https://faststream.airt.ai/latest/getting-started).

### Caveats

- When exception is raised in consumer handler, the message will be nacked (FastStream doesn't do this by default)
2 changes: 2 additions & 0 deletions packages/faststream-stomp/faststream_stomp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from faststream_stomp.broker import StompBroker
from faststream_stomp.message import StompStreamMessage
from faststream_stomp.publisher import StompPublisher
from faststream_stomp.router import StompRoute, StompRoutePublisher, StompRouter
from faststream_stomp.subscriber import StompSubscriber
Expand All @@ -10,6 +11,7 @@
"StompRoute",
"StompRoutePublisher",
"StompRouter",
"StompStreamMessage",
"StompSubscriber",
"TestStompBroker",
]
33 changes: 33 additions & 0 deletions packages/faststream-stomp/faststream_stomp/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import typing
from typing import cast

import stompman
from faststream.broker.message import StreamMessage, gen_cor_id


class StompStreamMessage(StreamMessage[stompman.AckableMessageFrame]):
async def ack(self) -> None:
if not self.committed:
await self.raw_message.ack()
return await super().ack()

async def nack(self) -> None:
if not self.committed:
await self.raw_message.nack()
return await super().nack()

async def reject(self) -> None:
if not self.committed:
await self.raw_message.nack()
return await super().reject()

@classmethod
async def from_frame(cls, message: stompman.AckableMessageFrame) -> typing.Self:
return cls(
raw_message=message,
body=message.body,
headers=cast("dict[str, str]", message.headers),
content_type=message.headers.get("content-type"),
message_id=message.headers["message-id"],
correlation_id=cast("str", message.headers.get("correlation-id", gen_cor_id())),
)
20 changes: 0 additions & 20 deletions packages/faststream-stomp/faststream_stomp/parser.py

This file was deleted.

15 changes: 5 additions & 10 deletions packages/faststream-stomp/faststream_stomp/registrator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, cast

import stompman
Expand All @@ -11,9 +11,6 @@
from faststream_stomp.subscriber import StompSubscriber


def noop_handle_suppressed_exception(exception: Exception, message: stompman.MessageFrame) -> None: ...


class StompRegistrator(ABCBroker[stompman.MessageFrame]):
_subscribers: Mapping[int, StompSubscriber]
_publishers: Mapping[int, StompPublisher]
Expand All @@ -22,12 +19,11 @@ def subscriber( # type: ignore[override]
self,
destination: str,
*,
ack: stompman.AckMode = "client-individual",
ack_mode: stompman.AckMode = "client-individual",
headers: dict[str, str] | None = None,
on_suppressed_exception: Callable[[Exception, stompman.MessageFrame], Any] = noop_handle_suppressed_exception,
suppressed_exception_classes: tuple[type[Exception], ...] = (Exception,),
# other args
dependencies: Iterable[Depends] = (),
no_ack: bool = False,
parser: CustomCallable | None = None,
decoder: CustomCallable | None = None,
middlewares: Sequence[SubscriberMiddleware[stompman.MessageFrame]] = (),
Expand All @@ -41,11 +37,10 @@ def subscriber( # type: ignore[override]
super().subscriber(
StompSubscriber(
destination=destination,
ack=ack,
ack_mode=ack_mode,
headers=headers,
on_suppressed_exception=on_suppressed_exception,
suppressed_exception_classes=suppressed_exception_classes,
retry=retry,
no_ack=no_ack,
broker_middlewares=self._middlewares,
broker_dependencies=self._dependencies,
title_=title,
Expand Down
12 changes: 5 additions & 7 deletions packages/faststream-stomp/faststream_stomp/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from faststream.broker.types import BrokerMiddleware, CustomCallable, PublisherMiddleware, SubscriberMiddleware
from faststream.types import SendableMessage

from faststream_stomp.registrator import StompRegistrator, noop_handle_suppressed_exception
from faststream_stomp.registrator import StompRegistrator


class StompRoutePublisher(ArgsContainer):
Expand Down Expand Up @@ -47,13 +47,12 @@ def __init__(
call: Callable[..., SendableMessage] | Callable[..., Awaitable[SendableMessage]],
destination: str,
*,
ack: stompman.AckMode = "client-individual",
ack_mode: stompman.AckMode = "client-individual",
headers: dict[str, str] | None = None,
on_suppressed_exception: Callable[[Exception, stompman.MessageFrame], Any] = noop_handle_suppressed_exception,
suppressed_exception_classes: tuple[type[Exception], ...] = (Exception,),
# other args
publishers: Iterable[StompRoutePublisher] = (),
dependencies: Iterable[Depends] = (),
no_ack: bool = False,
parser: CustomCallable | None = None,
decoder: CustomCallable | None = None,
middlewares: Sequence[SubscriberMiddleware[stompman.MessageFrame]] = (),
Expand All @@ -65,12 +64,11 @@ def __init__(
super().__init__(
call=call,
destination=destination,
ack=ack,
ack_mode=ack_mode,
headers=headers,
on_suppressed_exception=on_suppressed_exception,
suppressed_exception_classes=suppressed_exception_classes,
publishers=publishers,
dependencies=dependencies,
no_ack=no_ack,
parser=parser,
decoder=decoder,
middlewares=middlewares,
Expand Down
30 changes: 13 additions & 17 deletions packages/faststream-stomp/faststream_stomp/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,42 @@
from fast_depends.dependencies import Depends
from faststream.asyncapi.schema import Channel, CorrelationId, Message, Operation
from faststream.asyncapi.utils import resolve_payloads
from faststream.broker.message import StreamMessage
from faststream.broker.message import StreamMessage, decode_message
from faststream.broker.publisher.fake import FakePublisher
from faststream.broker.publisher.proto import ProducerProto
from faststream.broker.subscriber.usecase import SubscriberUsecase
from faststream.broker.types import AsyncCallable, BrokerMiddleware, CustomCallable
from faststream.types import AnyDict, Decorator, LoggerProto
from faststream.utils.functions import to_async

from faststream_stomp import parser
from faststream_stomp.message import StompStreamMessage


class StompSubscriber(SubscriberUsecase[stompman.MessageFrame]):
def __init__(
self,
*,
destination: str,
ack: stompman.AckMode = "client-individual",
headers: dict[str, str] | None = None,
on_suppressed_exception: Callable[[Exception, stompman.MessageFrame], Any],
suppressed_exception_classes: tuple[type[Exception], ...] = (Exception,),
ack_mode: stompman.AckMode,
headers: dict[str, str] | None,
retry: bool | int,
no_ack: bool,
broker_dependencies: Iterable[Depends],
broker_middlewares: Sequence[BrokerMiddleware[stompman.MessageFrame]],
default_parser: AsyncCallable = parser.parse_message,
default_decoder: AsyncCallable = parser.decode_message,
default_parser: AsyncCallable = StompStreamMessage.from_frame,
default_decoder: AsyncCallable = to_async(decode_message), # noqa: B008
# AsyncAPI information
title_: str | None,
description_: str | None,
include_in_schema: bool,
) -> None:
self.destination = destination
self.ack = ack
self.ack_mode = ack_mode
self.headers = headers
self.on_suppressed_exception = on_suppressed_exception
self.suppressed_exception_classes = suppressed_exception_classes
self._subscription: stompman.AutoAckSubscription | None = None
self._subscription: stompman.ManualAckSubscription | None = None

super().__init__(
no_ack=self.ack == "auto",
no_ack=no_ack or self.ack_mode == "auto",
no_reply=True,
retry=retry,
broker_dependencies=broker_dependencies,
Expand Down Expand Up @@ -85,13 +83,11 @@ def setup( # type: ignore[override]

async def start(self) -> None:
await super().start()
self._subscription = await self.client.subscribe(
self._subscription = await self.client.subscribe_with_manual_ack(
destination=self.destination,
handler=self.consume,
ack=self.ack,
ack=self.ack_mode,
headers=self.headers,
on_suppressed_exception=self.on_suppressed_exception,
suppressed_exception_classes=self.suppressed_exception_classes,
)

async def close(self) -> None:
Expand Down
11 changes: 9 additions & 2 deletions packages/faststream-stomp/faststream_stomp/testing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import uuid
from typing import TYPE_CHECKING, Any
from unittest import mock
from unittest.mock import AsyncMock

import stompman
from faststream.broker.message import encode_message
from faststream.testing.broker import TestBroker
from faststream.types import SendableMessage
from stompman import MessageFrame

from faststream_stomp.broker import StompBroker
from faststream_stomp.publisher import StompProducer, StompPublisher
Expand Down Expand Up @@ -44,6 +45,12 @@ async def _fake_connect(
broker._producer = FakeStompProducer(broker) # noqa: SLF001


class FakeAckableMessageFrame(stompman.AckableMessageFrame):
async def ack(self) -> None: ...

async def nack(self) -> None: ...


class FakeStompProducer(StompProducer):
def __init__(self, broker: StompBroker) -> None:
self.broker = broker
Expand All @@ -66,7 +73,7 @@ async def publish( # type: ignore[override]
all_headers["correlation-id"] = correlation_id # type: ignore[typeddict-unknown-key]
if content_type:
all_headers["content-type"] = content_type
frame = MessageFrame(headers=all_headers, body=body)
frame = FakeAckableMessageFrame(headers=all_headers, body=body, _subscription=mock.AsyncMock())

for handler in self.broker._subscribers.values(): # noqa: SLF001
if handler.destination == destination:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import asyncio
from typing import Annotated

import faker
import faststream_stomp
import pytest
import stompman
from faststream import BaseMiddleware, Context, FastStream
from faststream.broker.message import gen_cor_id
from faststream.exceptions import AckMessage, NackMessage, RejectMessage
from faststream_stomp.message import StompStreamMessage

pytestmark = pytest.mark.anyio


@pytest.fixture
def broker(connection_parameters: stompman.ConnectionParameters) -> faststream_stomp.StompBroker:
return faststream_stomp.StompBroker(stompman.Client([connection_parameters]))
def broker(first_server_connection_parameters: stompman.ConnectionParameters) -> faststream_stomp.StompBroker:
return faststream_stomp.StompBroker(stompman.Client([first_server_connection_parameters]))


async def test_simple(faker: faker.Faker, broker: faststream_stomp.StompBroker) -> None:
Expand Down Expand Up @@ -114,3 +117,37 @@ async def test_no_connection(self, broker: faststream_stomp.StompBroker) -> None
async def test_timeout(self, broker: faststream_stomp.StompBroker) -> None:
async with broker:
assert not await broker.ping(0)


@pytest.mark.parametrize("exception", [Exception, NackMessage, AckMessage, RejectMessage])
async def test_ack_nack_reject_exception(
faker: faker.Faker, broker: faststream_stomp.StompBroker, exception: type[Exception]
) -> None:
event = asyncio.Event()

@broker.subscriber(destination := faker.pystr())
def _() -> None:
event.set()
raise exception

async with broker:
await broker.start()
await broker.publish(faker.pystr(), destination)
await event.wait()


@pytest.mark.parametrize("method_name", ["ack", "nack", "reject"])
async def test_ack_nack_reject_method_call(
faker: faker.Faker, broker: faststream_stomp.StompBroker, method_name: str
) -> None:
event = asyncio.Event()

@broker.subscriber(destination := faker.pystr())
async def _(message: Annotated[StompStreamMessage, Context()]) -> None:
await getattr(message, method_name)()
event.set()

async with broker:
await broker.start()
await broker.publish(faker.pystr(), destination)
await event.wait()

0 comments on commit 82a56b7

Please sign in to comment.