diff --git a/README.md b/README.md index 2284a23..31e7227 100644 --- a/README.md +++ b/README.md @@ -243,6 +243,8 @@ If you want some new or better algorithms integrated just make a PR # TODO +- add tests and documentation for passthrough +- add tests and documentation for gas - stop when an open path regex is used. May append an invalid char and check if it is still ignoring - keep an eye on the performance impact of the new path regex checking - add tests for auto_snakecase and camelcase_path diff --git a/graphene_protector/base.py b/graphene_protector/base.py index 5ed0bc4..fe0a589 100644 --- a/graphene_protector/base.py +++ b/graphene_protector/base.py @@ -1,7 +1,8 @@ import re +from collections.abc import Callable from dataclasses import fields, replace from functools import wraps -from typing import Callable, List, Tuple +from typing import List, Tuple, Union from graphql import GraphQLInterfaceType, GraphQLObjectType, GraphQLUnionType from graphql.error import GraphQLError @@ -24,8 +25,10 @@ ComplexityLimitReached, DepthLimitReached, EarlyStop, + GasLimitReached, Limits, SelectionsLimitReached, + UsagesResult, default_path_ignore_pattern, ) @@ -45,9 +48,7 @@ def to_camel_case(snake_str): components = snake_str.split("_") # We capitalize the first letter of each component except the first one # with the 'capitalize' method and join them together. - return components[0] + "".join( - x.capitalize() if x else "_" for x in components[1:] - ) + return components[0] + "".join(x.capitalize() if x else "_" for x in components[1:]) # From this response in Stackoverflow @@ -68,17 +69,31 @@ def merge_limits(old_limits: Limits, new_limits: Limits): return replace(old_limits, **_limits) -def _extract_limits(scheme_field) -> Limits: +def _extract_limits(schema_field) -> Limits: while True: - if hasattr(scheme_field, "_graphene_protector_limits"): - return getattr(scheme_field, "_graphene_protector_limits") - if hasattr(scheme_field, "__func__"): - scheme_field = getattr(scheme_field, "__func__") + if hasattr(schema_field, "_graphene_protector_limits"): + return getattr(schema_field, "_graphene_protector_limits") + if hasattr(schema_field, "__func__"): + schema_field = getattr(schema_field, "__func__") else: break return MISSING_LIMITS +def gas_for_field(schema_field, **kwargs) -> int: + while True: + if hasattr(schema_field, "_graphene_protector_gas"): + retval = getattr(schema_field, "_graphene_protector_gas") + if callable(retval): + retval = retval(**kwargs) + return retval + if hasattr(schema_field, "__func__"): + schema_field = getattr(schema_field, "__func__") + else: + break + return 0 + + def limits_for_field(field, old_limits, **kwargs) -> Tuple[Limits, Limits]: # retrieve optional limitation attributes defined for the current # operation @@ -98,21 +113,21 @@ def check_resource_usage( camelcase_path=True, path_ignore_pattern: re.Pattern = _default_path_ignore_pattern, get_limits_for_field=limits_for_field, + get_gas_for_field=gas_for_field, level_depth=0, level_complexity=0, _seen_limits=None, _path="", -) -> Tuple[int, int, int]: +) -> UsagesResult: if _seen_limits is None: _seen_limits = set() # level 0: starts on query level. Every query is level 1 - selections = 0 - max_level_depth = level_depth - max_level_complexity = level_complexity - assert ( - limits.depth is not MISSING - ), "missing should be already resolved here" - if limits.depth and max_level_depth > limits.depth: + retval = UsagesResult( + max_level_depth=level_depth, + max_level_complexity=level_complexity, + ) + assert limits.depth is not MISSING, "missing should be already resolved here" + if limits.depth and retval.max_level_depth > limits.depth: on_error( DepthLimitReached( "Query is too deep", @@ -133,7 +148,8 @@ def check_resource_usage( if isinstance(field, (GraphQLUnionType, GraphQLInterfaceType)): merged_limits = limits - local_selections = 0 + local_union_selections = 0 + local_gas = 0 field_contributes_to_score = True _npath = "{}/{}".format( @@ -142,14 +158,8 @@ def check_resource_usage( ) if path_ignore_pattern.match(_npath): field_contributes_to_score = False - for field_type in validation_context.schema.get_possible_types( - field - ): - ( - new_depth, - new_depth_complexity, - local2_selections, - ) = check_resource_usage( + for field_type in validation_context.schema.get_possible_types(field): + local_result = check_resource_usage( follow_of_type(field_type), field, validation_context, @@ -159,6 +169,7 @@ def check_resource_usage( camelcase_path=camelcase_path, path_ignore_pattern=path_ignore_pattern, get_limits_for_field=get_limits_for_field, + get_gas_for_field=get_gas_for_field, level_depth=level_depth + 1 if field_contributes_to_score else level_depth, @@ -173,23 +184,28 @@ def check_resource_usage( # called per query, selection if ( merged_limits.complexity - and (new_depth_complexity - level_complexity) - * local2_selections + and (local_result.max_level_complexity - level_complexity) + * local_result.selections > merged_limits.complexity ): - on_error( - ComplexityLimitReached("Query is too complex", node) - ) + on_error(ComplexityLimitReached("Query is too complex", node)) + # find max of selections for unions + if local_result.selections > local_union_selections: + local_union_selections = local_result.selections # find max of selections for unions - if local2_selections > local_selections: - local_selections = local2_selections - if new_depth > max_level_depth: - max_level_depth = new_depth - if new_depth_complexity > max_level_complexity: - max_level_complexity = new_depth_complexity + if local_result.gas_used > local_gas: + local_gas = local_result.gas_used + if local_result.max_level_depth > retval.max_level_depth: + retval.max_level_depth = local_result.max_level_depth + if local_result.max_level_complexity > retval.max_level_complexity: + retval.max_level_complexity = local_result.max_level_complexity # ignore union fields itself for selection_count # because we have depth for that - selections += local_selections + retval.selections += local_union_selections + # gas for field itself already calculated in field.selection_set + retval.gas_used += local_gas + del local_union_selections + del local_gas elif field.selection_set: try: schema_field = getattr(schema, fieldname) @@ -213,6 +229,12 @@ def check_resource_usage( parent=schema, fieldname=fieldname, ) + # add gas for selection field + retval.gas_used += get_gas_for_field( + schema_field, + parent=schema, + fieldname=fieldname, + ) allow_reset = True field_contributes_to_score = True _npath = "{}/{}".format( @@ -234,11 +256,7 @@ def check_resource_usage( sub_field_type = schema_field else: sub_field_type = follow_of_type(schema_field.type) - ( - new_depth, - new_depth_complexity, - local_selections, - ) = check_resource_usage( + local_result = check_resource_usage( sub_field_type, field, validation_context, @@ -248,6 +266,7 @@ def check_resource_usage( camelcase_path=camelcase_path, path_ignore_pattern=path_ignore_pattern, get_limits_for_field=get_limits_for_field, + get_gas_for_field=get_gas_for_field, # field_contributes_to_score will be casted to 1 for True level_depth=level_depth + field_contributes_to_score if sub_limits.depth is MISSING or not allow_reset @@ -261,33 +280,71 @@ def check_resource_usage( # called per query, selection if ( merged_limits.complexity - and (new_depth - level_depth) * local_selections + and (local_result.max_level_depth - level_depth) + * local_result.selections > merged_limits.complexity ): on_error(ComplexityLimitReached("Query is too complex", node)) # increase level counter only if limits are not redefined - if sub_limits.depth is MISSING and new_depth > max_level_depth: - max_level_depth = new_depth + if ( + sub_limits.depth is MISSING or "depth" in sub_limits.passthrough + ) and local_result.max_level_depth > retval.max_level_depth: + retval.max_level_depth = local_result.max_level_depth if ( sub_limits.complexity is MISSING - and new_depth_complexity > max_level_complexity - ): - max_level_complexity = new_depth_complexity + or "complexity" in sub_limits.passthrough + ) and local_result.max_level_complexity > retval.max_level_complexity: + retval.max_level_complexity = local_result.max_level_complexity # ignore fields with selection_set itself for selection_count # because we have depth for that - if sub_limits.selections is MISSING: - selections += local_selections + if ( + sub_limits.selections is MISSING + or "selections" in sub_limits.passthrough + ): + retval.selections += local_result.selections + if sub_limits.gas is MISSING or "gas" in sub_limits.passthrough: + retval.gas_used += local_result.gas_used + del schema_field else: - field_contributes_to_score = True - if path_ignore_pattern.match(_path): - field_contributes_to_score = False - # field_contributes_to_score will be casted to 1 for True - selections += field_contributes_to_score + try: + schema_field = getattr(schema, fieldname) + except AttributeError: + _name = None + if hasattr(field, "name"): + _name = field.name + if hasattr(_name, "value"): + _name = _name.value + if ( + hasattr(schema, "fields") + and not isinstance(schema, GraphQLInterfaceType) + and _name + ): + schema_field = schema.fields[_name] + else: + schema_field = schema + retval.gas_used += get_gas_for_field( + schema_field, + parent=schema, + fieldname=fieldname, + ) + if not path_ignore_pattern.match(_path): + # field_contributes_to_score + retval.selections += 1 - if limits.selections and selections > limits.selections: + if limits.selections and retval.selections > limits.selections: on_error(SelectionsLimitReached("Query selects too much", node)) - return max_level_depth, max_level_complexity, selections + if limits.gas and retval.gas_used > limits.gas: + on_error(GasLimitReached("Query uses too much gas", node)) + return retval + + +def gas_usage(gas_used: Union[Callable[[], int], int]): + def wrapper(schema_field): + setattr(schema_field, "_graphene_protector_gas", gas_used) + return schema_field + + return wrapper class LimitsValidationRule(ValidationRule): @@ -351,6 +408,7 @@ def enter(self, node, key, parent, path, ancestors): maintype = schema.get_type(operation_type) assert maintype is not None get_limits_for_field = limits_for_field + get_gas_for_field = gas_for_field if hasattr(maintype, "graphene_type"): maintype = maintype.graphene_type if hasattr(schema, "_strawberry_schema"): @@ -360,9 +418,7 @@ def get_limits_for_field( ): name = follow_of_type(parent).name definition = ( - schema._strawberry_schema.schema_converter.type_map[ - name - ] + schema._strawberry_schema.schema_converter.type_map[name] ).definition # e.g. union if not hasattr(definition, "get_field"): @@ -371,6 +427,18 @@ def get_limits_for_field( nfield = definition.get_field(fieldname) return limits_for_field(nfield, old_limits) + def get_gas_for_field(field, parent, fieldname, **kwargs): + name = follow_of_type(parent).name + definition = ( + schema._strawberry_schema.schema_converter.type_map[name] + ).definition + # e.g. union + if not hasattr(definition, "get_field"): + return gas_for_field(definition) + + nfield = definition.get_field(fieldname) + return gas_for_field(nfield) + if getattr(self, "protector_on", True): try: check_resource_usage( @@ -380,6 +448,7 @@ def get_limits_for_field( self.default_limits, self.report_error, get_limits_for_field=get_limits_for_field, + get_gas_for_field=get_gas_for_field, auto_snakecase=self.auto_snakecase, camelcase_path=self.camelcase_path, path_ignore_pattern=self.path_ignore_pattern, @@ -461,9 +530,7 @@ class SchemaMixin: protector_default_limits = None protector_path_ignore_pattern = default_path_ignore_pattern - def __init_subclass__( - cls, protector_per_operation_validation=True, **kwargs - ): + def __init_subclass__(cls, protector_per_operation_validation=True, **kwargs): if hasattr(cls, "execute_sync"): cls.execute_sync = decorate_limits( cls.execute_sync, protector_per_operation_validation diff --git a/graphene_protector/django/base.py b/graphene_protector/django/base.py index 8338c4f..179b838 100644 --- a/graphene_protector/django/base.py +++ b/graphene_protector/django/base.py @@ -24,6 +24,9 @@ def get_protector_default_limits(self): complexity=_get_default_limit_from_settings( "GRAPHENE_PROTECTOR_COMPLEXITY_LIMIT" ), + gas=_get_default_limit_from_settings( + "GRAPHENE_PROTECTOR_GAS_LIMIT" + ), ), ), self.protector_default_limits, diff --git a/graphene_protector/misc.py b/graphene_protector/misc.py index 0c9df20..8600a5c 100644 --- a/graphene_protector/misc.py +++ b/graphene_protector/misc.py @@ -1,6 +1,6 @@ import sys -from dataclasses import dataclass -from typing import Union +from dataclasses import dataclass, field +from typing import Set, Union from graphql.error import GraphQLError @@ -14,6 +14,7 @@ class MISSING: _deco_options = {} if sys.version_info >= (3, 10): + _deco_options["kw_only"] = True _deco_options["slots"] = True if sys.version_info >= (3, 11): @@ -25,14 +26,26 @@ class Limits: depth: Union[int, None, MISSING] = MISSING selections: Union[int, None, MISSING] = MISSING complexity: Union[int, None, MISSING] = MISSING + gas: Union[int, None, MISSING] = MISSING + # only for sublimits not for main Limit instance + # passthrough for not missing limits + passthrough: Set[str] = field(default_factory=set) def __call__(self, field): setattr(field, "_graphene_protector_limits", self) return field +@dataclass(**_deco_options) +class UsagesResult: + max_level_depth: int = 0 + max_level_complexity: int = 0 + selections: int = 0 + gas_used: int = 0 + + MISSING_LIMITS = Limits() -DEFAULT_LIMITS = Limits(depth=20, selections=None, complexity=100) +DEFAULT_LIMITS = Limits(depth=20, selections=None, complexity=100, gas=None) class EarlyStop(Exception): @@ -51,6 +64,10 @@ class SelectionsLimitReached(ResourceLimitReached): pass +class GasLimitReached(ResourceLimitReached): + pass + + class ComplexityLimitReached(ResourceLimitReached): pass diff --git a/tests/test_graphql_core.py b/tests/test_graphql_core.py index 4eebf91..5c1981f 100644 --- a/tests/test_graphql_core.py +++ b/tests/test_graphql_core.py @@ -2,17 +2,17 @@ import unittest -from graphql import validate, parse +from graphql import parse, validate from graphql.type import GraphQLSchema -from graphene_protector import Limits, SchemaMixin, ValidationRule +from graphene_protector import Limits, SchemaMixin, ValidationRule from .graphql.schema import Query class Schema(GraphQLSchema, SchemaMixin): protector_default_limits = Limits( - depth=2, selections=None, complexity=None + depth=2, selections=None, complexity=None, gas=None ) auto_camelcase = False diff --git a/tests/testgraphene_base.py b/tests/testgraphene_base.py index 268add0..d714383 100644 --- a/tests/testgraphene_base.py +++ b/tests/testgraphene_base.py @@ -15,7 +15,7 @@ class TestGraphene(unittest.TestCase): def test_simple(self): schema = ProtectorSchema( query=Query, - limits=Limits(depth=2, selections=None, complexity=None), + limits=Limits(depth=2, selections=None, complexity=None, gas=None), types=[SomeNode], ) self.assertIsInstance(schema, GrapheneSchema) @@ -26,7 +26,7 @@ def test_simple(self): def test_node(self): schema = ProtectorSchema( query=Query, - limits=Limits(depth=2, selections=None, complexity=None), + limits=Limits(depth=2, selections=None, complexity=None, gas=None), types=[SomeNode], ) self.assertIsInstance(schema, GrapheneSchema) @@ -67,16 +67,14 @@ def test_success_connection(self): self.assertFalse(result.errors) self.assertEqual(len(result.data["someNodes"]["edges"]), 100) self.assertEqual( - from_global_id( - result.data["someNodes"]["edges"][99]["node"]["id"] - )[1], + from_global_id(result.data["someNodes"]["edges"][99]["node"]["id"])[1], "id-199", ) def test_error_connection(self): schema = ProtectorSchema( query=Query, - limits=Limits(depth=3, selections=None, complexity=None), + limits=Limits(depth=3, selections=None, complexity=None, gas=None), ) self.assertIsInstance(schema, GrapheneSchema) with self.subTest("success"): diff --git a/tests/testgraphene_camelcase.py b/tests/testgraphene_camelcase.py index b9883c9..0b43e6a 100644 --- a/tests/testgraphene_camelcase.py +++ b/tests/testgraphene_camelcase.py @@ -1,6 +1,7 @@ __package__ = "tests" import unittest + import graphene from graphene_protector import Limits @@ -33,7 +34,6 @@ def resolve_child_a(self, info): class Query(graphene.ObjectType): - set_directly = Limits(depth=2)(graphene.Field(Person)) unset_directly = Limits(depth=None)(graphene.Field(Person)) unset_hierachy = Limits(depth=1)(graphene.Field(Person3)) @@ -51,7 +51,7 @@ def resolve_unset_hierachy(root, info): schema = ProtectorSchema( query=Query, auto_camelcase=True, - limits=Limits(depth=3, selections=None, complexity=None), + limits=Limits(depth=3, selections=None, complexity=None, gas=None), ) diff --git a/tests/testgraphene_field.py b/tests/testgraphene_field.py index 763c819..6f6c7d1 100644 --- a/tests/testgraphene_field.py +++ b/tests/testgraphene_field.py @@ -1,10 +1,12 @@ __package__ = "tests" import unittest + import graphene from graphene_protector import Limits from graphene_protector.graphene import Schema as ProtectorSchema + from .graphene_base import Person @@ -35,7 +37,6 @@ def resolve_child(self, info): class Query(graphene.ObjectType): - setDirectly = Limits(depth=2)(graphene.Field(Person)) unsetDirectly = Limits(depth=None)(graphene.Field(Person)) unsetHierachy = Limits(depth=1)(graphene.Field(Person3)) @@ -55,7 +56,7 @@ def resolve_setHierachy(root, info): schema = ProtectorSchema( - query=Query, limits=Limits(depth=3, selections=None, complexity=None) + query=Query, limits=Limits(depth=3, selections=None, complexity=None, gas=None) ) diff --git a/tests/testgraphene_global.py b/tests/testgraphene_global.py index e601eda..9c1ab6a 100644 --- a/tests/testgraphene_global.py +++ b/tests/testgraphene_global.py @@ -29,12 +29,12 @@ def test_defaults(self): for field in fields(dlimit): with self.subTest(f"Test default: {field.name}"): limit = getattr(dlimit, field.name) - self.assertTrue(limit is None or limit >= 0) + self.assertTrue(limit is None or isinstance(limit, set) or limit >= 0) def test_depth(self): schema = ProtectorSchema( query=Query, - limits=Limits(depth=2, selections=None, complexity=None), + limits=Limits(depth=2, selections=None, complexity=None, gas=None), ) with self.subTest("success"): @@ -93,7 +93,7 @@ def test_depth(self): def test_selections(self): schema = ProtectorSchema( query=Query, - limits=Limits(selections=2, depth=None, complexity=None), + limits=Limits(selections=2, depth=None, complexity=None, gas=None), ) with self.subTest("success 1"): query = """ @@ -169,7 +169,7 @@ def test_selections(self): def test_complexity(self): schema = ProtectorSchema( query=Query, - limits=Limits(selections=None, depth=None, complexity=2), + limits=Limits(selections=None, depth=None, complexity=2, gas=None), ) with self.subTest("success 1"): query = """ diff --git a/tests/teststrawberry_base.py b/tests/teststrawberry_base.py index 27c5dc1..1b26617 100644 --- a/tests/teststrawberry_base.py +++ b/tests/teststrawberry_base.py @@ -15,7 +15,7 @@ class CustomSchema(SchemaMixin, StrawberrySchema): protector_default_limits = Limits( - depth=2, selections=None, complexity=None + depth=2, selections=None, complexity=None, gas=None ) @@ -23,7 +23,7 @@ class CustomSchemaWithoutOperationWrapping( SchemaMixin, StrawberrySchema, protector_per_operation_validation=False ): protector_default_limits = Limits( - depth=2, selections=None, complexity=None + depth=2, selections=None, complexity=None, gas=None ) @@ -31,7 +31,7 @@ class TestStrawberry(unittest.IsolatedAsyncioTestCase): def test_simple_sync(self): schema = ProtectorSchema( query=Query, - limits=Limits(depth=2, selections=None, complexity=None), + limits=Limits(depth=2, selections=None, complexity=None, gas=None), ) self.assertIsInstance(schema, StrawberrySchema) self.assertTrue( @@ -56,7 +56,7 @@ def test_simple_sync(self): def test_failing_sync(self): schema = ProtectorSchema( query=Query, - limits=Limits(depth=1, selections=None, complexity=None), + limits=Limits(depth=1, selections=None, complexity=None, gas=None), ) self.assertIsInstance(schema, StrawberrySchema) self.assertTrue( @@ -82,7 +82,7 @@ def test_failing_sync(self): async def test_failing_async(self): schema = ProtectorSchema( query=Query, - limits=Limits(depth=1, selections=None, complexity=None), + limits=Limits(depth=1, selections=None, complexity=None, gas=None), ) self.assertIsInstance(schema, StrawberrySchema) self.assertTrue( @@ -108,7 +108,7 @@ async def test_failing_async(self): def test_in_out(self): schema = ProtectorSchema( query=Query, - limits=Limits(depth=2, selections=None, complexity=None), + limits=Limits(depth=2, selections=None, complexity=None, gas=None), ) self.assertIsInstance(schema, StrawberrySchema) result = schema.execute_sync('{ inOut(into: ["a", "b"]) }') @@ -118,7 +118,7 @@ def test_in_out(self): async def test_simple_async(self): schema = ProtectorSchema( query=Query, - limits=Limits(depth=2, selections=None, complexity=None), + limits=Limits(depth=2, selections=None, complexity=None, gas=None), ) self.assertIsInstance(schema, StrawberrySchema) result = await schema.execute( @@ -154,7 +154,7 @@ async def test_failing_async_custom(self): query=Query, extensions=[ CustomGrapheneProtector( - limits=Limits(depth=2, selections=None, complexity=None), + limits=Limits(depth=2, selections=None, complexity=None, gas=None), ) ], ) @@ -201,9 +201,7 @@ async def test_success_node_async(self): % (to_base64("SomeNode", "foo")) ) self.assertFalse(result.errors) - self.assertEqual( - result.data["node"]["id"], to_base64("SomeNode", "foo") - ) + self.assertEqual(result.data["node"]["id"], to_base64("SomeNode", "foo")) async def test_success_connection_async(self): schema = ProtectorSchema( @@ -234,8 +232,6 @@ async def test_success_connection_async(self): self.assertFalse(result.errors) self.assertEqual(len(result.data["someNodes"]["edges"]), 100) self.assertEqual( - from_base64(result.data["someNodes"]["edges"][99]["node"]["id"])[ - 1 - ], + from_base64(result.data["someNodes"]["edges"][99]["node"]["id"])[1], "id-199", )