Skip to content

Commit

Permalink
feat: RealyQueryType now inherits from RelayObjectType to give query …
Browse files Browse the repository at this point in the history
…the option to use connections, add unit tests
  • Loading branch information
pkucmus committed Jan 29, 2025
1 parent 5802494 commit c4ffc4b
Show file tree
Hide file tree
Showing 7 changed files with 627 additions and 53 deletions.
14 changes: 13 additions & 1 deletion ariadne/contrib/relay/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,19 @@ def __init__(self, *, last: int | None = None, before: str | None = None) -> Non
self.before = before


class ConnectionArguments(ForwardConnectionArguments, BackwardConnectionArguments): ...
class ConnectionArguments:
def __init__(
self,
*,
first: int | None = None,
after: str | None = None,
last: int | None = None,
before: str | None = None,
) -> None:
self.first = first
self.after = after
self.last = last
self.before = before


ConnectionArgumentsUnion = TypeAliasType(
Expand Down
102 changes: 50 additions & 52 deletions ariadne/contrib/relay/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from graphql.pyutils import is_awaitable

from ariadne import InterfaceType, ObjectType, QueryType
from ariadne import InterfaceType, ObjectType
from ariadne.contrib.relay.arguments import (
ConnectionArguments,
ConnectionArgumentsTypeUnion,
Expand All @@ -21,6 +21,53 @@ def decode_global_id(kwargs) -> GlobalIDTuple:
return GlobalIDTuple(*b64decode(kwargs["id"]).decode().split(":"))


class RelayObjectType(ObjectType):
def __init__(
self,
name: str,
connection_arguments_class: ConnectionArgumentsTypeUnion = ConnectionArguments,
) -> None:
super().__init__(name)
self.connection_arguments_class = connection_arguments_class

def resolve_wrapper(self, resolver: ConnectionResolver):
def wrapper(obj, info, *args, **kwargs):
connection_arguments = self.connection_arguments_class(**kwargs)
if iscoroutinefunction(resolver):

async def async_my_extension():
relay_connection = await resolver(
obj, info, connection_arguments, *args, **kwargs
)
if is_awaitable(relay_connection):
relay_connection = await relay_connection
return {
"edges": relay_connection.get_edges(),
"pageInfo": relay_connection.get_page_info(
connection_arguments
),
}

return async_my_extension()

relay_connection = resolver(
obj, info, connection_arguments, *args, **kwargs
)
return {
"edges": relay_connection.get_edges(),
"pageInfo": relay_connection.get_page_info(connection_arguments),
}

return wrapper

def connection(self, name: str):
def decorator(resolver: ConnectionResolver) -> ConnectionResolver:
self.set_field(name, self.resolve_wrapper(resolver))
return resolver

return decorator


class RelayNodeInterfaceType(InterfaceType):
def __init__(
self,
Expand Down Expand Up @@ -48,15 +95,13 @@ def get_node_resolver(self, type_name: str):
raise ValueError(f"No object resolver for type {type_name}") from exc


class RelayQueryType(QueryType):
class RelayQueryType(RelayObjectType):
def __init__(
self,
*args,
node=None,
node_field_resolver=None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
super().__init__("Query")
if not node:
node = RelayNodeInterfaceType()
self.node = node
Expand All @@ -81,50 +126,3 @@ async def async_my_extension():

return async_my_extension()
return resolver(obj, info, *args, **kwargs)


class RelayObjectType(ObjectType):
def __init__(
self,
name: str,
connection_arguments_class: ConnectionArgumentsTypeUnion = ConnectionArguments,
) -> None:
super().__init__(name)
self.connection_arguments_class = connection_arguments_class

def resolve_wrapper(self, resolver: ConnectionResolver):
def wrapper(obj, info, *args, **kwargs):
connection_arguments = self.connection_arguments_class(**kwargs)
if iscoroutinefunction(resolver):

async def async_my_extension():
relay_connection = await resolver(
obj, info, connection_arguments, *args, **kwargs
)
if is_awaitable(relay_connection):
relay_connection = await relay_connection
return {
"edges": relay_connection.get_edges(),
"pageInfo": relay_connection.get_page_info(
connection_arguments
),
}

return async_my_extension()

relay_connection = resolver(
obj, info, connection_arguments, *args, **kwargs
)
return {
"edges": relay_connection.get_edges(),
"pageInfo": relay_connection.get_page_info(connection_arguments),
}

return wrapper

def connection(self, name: str):
def decorator(resolver: ConnectionResolver) -> ConnectionResolver:
self.set_field(name, self.resolve_wrapper(resolver))
return resolver

return decorator
Empty file added tests/relay/__init__.py
Empty file.
192 changes: 192 additions & 0 deletions tests/relay/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from base64 import b64decode

import pytest
from ariadne.contrib.relay import (
ConnectionArguments,
GlobalIDTuple,
RelayConnection,
RelayNodeInterfaceType,
RelayQueryType,
)


@pytest.fixture
def relay_type_defs():
return """\
interface Node {
bid: ID!
}
type Faction implements Node {
bid: ID!
name: String
ships(first: Int!, after: ID): ShipConnection
}
type Ship implements Node {
bid: ID!
name: String
}
type ShipConnection {
edges: [ShipEdge]
pageInfo: PageInfo!
ships: [Ship]
totalCount: Int
}
type ShipEdge {
cursor: String!
node: Ship
}
type PageInfo {
hasNextPage: Boolean!
hasPreviousPage: Boolean!
startCursor: String
endCursor: String
}
type Query {
rebels: Faction
empire: Faction
node(bid: ID!): Node
}
"""


@pytest.fixture
def global_id_decoder():
return lambda kwargs: GlobalIDTuple(*b64decode(kwargs["bid"]).decode().split(":"))


@pytest.fixture
def relay_node_interface(global_id_decoder):
return RelayNodeInterfaceType(global_id_decoder=global_id_decoder)


@pytest.fixture
def relay_query(factions, relay_node_interface):
query = RelayQueryType(
node=relay_node_interface,
)
query.set_field("rebels", lambda *_: factions[0])
query.set_field("empire", lambda *_: factions[1])
query.node.set_field("bid", lambda obj, *_: obj["id"])
return query


@pytest.fixture
def ships():
return [
{
"id": "U2hpcDox",
"name": "X-Wing",
"factionId": "RmFjdGlvbjox",
},
{
"id": "U2hpcDoy",
"name": "Y-Wing",
"factionId": "RmFjdGlvbjox",
},
{
"id": "U2hpcDoz",
"name": "A-Wing",
"factionId": "RmFjdGlvbjox",
},
{
"id": "U2hpcDo0",
"name": "Millennium Falcon",
"factionId": "RmFjdGlvbjox",
},
{
"id": "U2hpcDo1",
"name": "Home One",
"factionId": "RmFjdGlvbjox",
},
{
"id": "U2hpcDo2",
"name": "TIE Fighter",
"factionId": "RmFjdGlvbjoy",
},
{
"id": "U2hpcDo3",
"name": "TIE Bomber",
"factionId": "RmFjdGlvbjoy",
},
{
"id": "U2hpcDo4",
"name": "TIE Interceptor",
"factionId": "RmFjdGlvbjoy",
},
{
"id": "U2hpcDo5",
"name": "Darth Vader's TIE Advanced",
"factionId": "RmFjdGlvbjoy",
},
]


@pytest.fixture
def factions():
return [
{
"id": "RmFjdGlvbjox",
"name": "Alliance to Restore the Republic",
},
{"id": "RmFjdGlvbjoy", "name": "Galactic Empire"},
]


@pytest.fixture
def relay_query_with_node_resolvers(relay_query, ships, factions):
relay_query.node.node_resolver("Faction")(
lambda *_, bid: [
{"__typename": "Faction", **faction}
for faction in factions
if faction["id"] == bid
][0]
)
relay_query.node.node_resolver("Ship")(
lambda *_, bid: [
{"__typename": "Ship", **ship} for ship in ships if ship["id"] == bid
][0]
)
return relay_query


@pytest.fixture
def ship_slice_resolver(ships):
# pylint: disable=unused-argument
def resolver(
faction_obj, info, connection_arguments: ConnectionArguments, **kwargs
):
faction_ships = [
ship for ship in ships if ship["factionId"] == faction_obj["id"]
]
total = len(faction_ships)
if connection_arguments.after:
after_index = (
faction_ships.index(
next(
ship
for ship in faction_ships
if ship["id"] == connection_arguments.after
)
)
+ 1
)
else:
after_index = 0
ships_slice = faction_ships[
after_index : after_index + connection_arguments.first
]

return RelayConnection(
edges=ships_slice,
total=total,
has_next_page=after_index + connection_arguments.first < total,
has_previous_page=after_index > 0,
)

return resolver
27 changes: 27 additions & 0 deletions tests/relay/test_arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from ariadne.contrib.relay.arguments import (
BackwardConnectionArguments,
ConnectionArguments,
ForwardConnectionArguments,
)


def test_connection_arguments():
connection_arguments = ConnectionArguments(
first=10, after="cursor", last=5, before="cursor"
)
assert connection_arguments.first == 10
assert connection_arguments.after == "cursor"
assert connection_arguments.last == 5
assert connection_arguments.before == "cursor"


def test_forward_connection_arguments():
connection_arguments = ForwardConnectionArguments(first=10, after="cursor")
assert connection_arguments.first == 10
assert connection_arguments.after == "cursor"


def test_backward_connection_arguments():
connection_arguments = BackwardConnectionArguments(last=5, before="cursor")
assert connection_arguments.last == 5
assert connection_arguments.before == "cursor"
Loading

0 comments on commit c4ffc4b

Please sign in to comment.