Skip to content

Commit

Permalink
add passthrough (Limits) and gas (new unit)
Browse files Browse the repository at this point in the history
  • Loading branch information
devkral committed Mar 15, 2024
1 parent 61c57e7 commit 2143eac
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 98 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ If you want some new or better algorithms integrated just make a PR

# TODO

- add tests and documentation for passthrough
- add tests and documentation for gas
- stop when an open path regex is used. May append an invalid char and check if it is still ignoring
- keep an eye on the performance impact of the new path regex checking
- add tests for auto_snakecase and camelcase_path
Expand Down
195 changes: 131 additions & 64 deletions graphene_protector/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
from collections.abc import Callable
from dataclasses import fields, replace
from functools import wraps
from typing import Callable, List, Tuple
from typing import List, Tuple, Union

from graphql import GraphQLInterfaceType, GraphQLObjectType, GraphQLUnionType
from graphql.error import GraphQLError
Expand All @@ -24,8 +25,10 @@
ComplexityLimitReached,
DepthLimitReached,
EarlyStop,
GasLimitReached,
Limits,
SelectionsLimitReached,
UsagesResult,
default_path_ignore_pattern,
)

Expand All @@ -45,9 +48,7 @@ def to_camel_case(snake_str):
components = snake_str.split("_")
# We capitalize the first letter of each component except the first one
# with the 'capitalize' method and join them together.
return components[0] + "".join(
x.capitalize() if x else "_" for x in components[1:]
)
return components[0] + "".join(x.capitalize() if x else "_" for x in components[1:])


# From this response in Stackoverflow
Expand All @@ -68,17 +69,31 @@ def merge_limits(old_limits: Limits, new_limits: Limits):
return replace(old_limits, **_limits)


def _extract_limits(scheme_field) -> Limits:
def _extract_limits(schema_field) -> Limits:
while True:
if hasattr(scheme_field, "_graphene_protector_limits"):
return getattr(scheme_field, "_graphene_protector_limits")
if hasattr(scheme_field, "__func__"):
scheme_field = getattr(scheme_field, "__func__")
if hasattr(schema_field, "_graphene_protector_limits"):
return getattr(schema_field, "_graphene_protector_limits")
if hasattr(schema_field, "__func__"):
schema_field = getattr(schema_field, "__func__")
else:
break
return MISSING_LIMITS


def gas_for_field(schema_field, **kwargs) -> int:
while True:
if hasattr(schema_field, "_graphene_protector_gas"):
retval = getattr(schema_field, "_graphene_protector_gas")
if callable(retval):
retval = retval(**kwargs)
return retval
if hasattr(schema_field, "__func__"):
schema_field = getattr(schema_field, "__func__")
else:
break
return 0


def limits_for_field(field, old_limits, **kwargs) -> Tuple[Limits, Limits]:
# retrieve optional limitation attributes defined for the current
# operation
Expand All @@ -98,21 +113,21 @@ def check_resource_usage(
camelcase_path=True,
path_ignore_pattern: re.Pattern = _default_path_ignore_pattern,
get_limits_for_field=limits_for_field,
get_gas_for_field=gas_for_field,
level_depth=0,
level_complexity=0,
_seen_limits=None,
_path="",
) -> Tuple[int, int, int]:
) -> UsagesResult:
if _seen_limits is None:
_seen_limits = set()
# level 0: starts on query level. Every query is level 1
selections = 0
max_level_depth = level_depth
max_level_complexity = level_complexity
assert (
limits.depth is not MISSING
), "missing should be already resolved here"
if limits.depth and max_level_depth > limits.depth:
retval = UsagesResult(
max_level_depth=level_depth,
max_level_complexity=level_complexity,
)
assert limits.depth is not MISSING, "missing should be already resolved here"
if limits.depth and retval.max_level_depth > limits.depth:
on_error(
DepthLimitReached(
"Query is too deep",
Expand All @@ -133,7 +148,8 @@ def check_resource_usage(

if isinstance(field, (GraphQLUnionType, GraphQLInterfaceType)):
merged_limits = limits
local_selections = 0
local_union_selections = 0
local_gas = 0

field_contributes_to_score = True
_npath = "{}/{}".format(
Expand All @@ -142,14 +158,8 @@ def check_resource_usage(
)
if path_ignore_pattern.match(_npath):
field_contributes_to_score = False
for field_type in validation_context.schema.get_possible_types(
field
):
(
new_depth,
new_depth_complexity,
local2_selections,
) = check_resource_usage(
for field_type in validation_context.schema.get_possible_types(field):
local_result = check_resource_usage(
follow_of_type(field_type),
field,
validation_context,
Expand All @@ -159,6 +169,7 @@ def check_resource_usage(
camelcase_path=camelcase_path,
path_ignore_pattern=path_ignore_pattern,
get_limits_for_field=get_limits_for_field,
get_gas_for_field=get_gas_for_field,
level_depth=level_depth + 1
if field_contributes_to_score
else level_depth,
Expand All @@ -173,23 +184,28 @@ def check_resource_usage(
# called per query, selection
if (
merged_limits.complexity
and (new_depth_complexity - level_complexity)
* local2_selections
and (local_result.max_level_complexity - level_complexity)
* local_result.selections
> merged_limits.complexity
):
on_error(
ComplexityLimitReached("Query is too complex", node)
)
on_error(ComplexityLimitReached("Query is too complex", node))
# find max of selections for unions
if local_result.selections > local_union_selections:
local_union_selections = local_result.selections
# find max of selections for unions
if local2_selections > local_selections:
local_selections = local2_selections
if new_depth > max_level_depth:
max_level_depth = new_depth
if new_depth_complexity > max_level_complexity:
max_level_complexity = new_depth_complexity
if local_result.gas_used > local_gas:
local_gas = local_result.gas_used
if local_result.max_level_depth > retval.max_level_depth:
retval.max_level_depth = local_result.max_level_depth
if local_result.max_level_complexity > retval.max_level_complexity:
retval.max_level_complexity = local_result.max_level_complexity
# ignore union fields itself for selection_count
# because we have depth for that
selections += local_selections
retval.selections += local_union_selections
# gas for field itself already calculated in field.selection_set
retval.gas_used += local_gas
del local_union_selections
del local_gas
elif field.selection_set:
try:
schema_field = getattr(schema, fieldname)
Expand All @@ -213,6 +229,12 @@ def check_resource_usage(
parent=schema,
fieldname=fieldname,
)
# add gas for selection field
retval.gas_used += get_gas_for_field(
schema_field,
parent=schema,
fieldname=fieldname,
)
allow_reset = True
field_contributes_to_score = True
_npath = "{}/{}".format(
Expand All @@ -234,11 +256,7 @@ def check_resource_usage(
sub_field_type = schema_field
else:
sub_field_type = follow_of_type(schema_field.type)
(
new_depth,
new_depth_complexity,
local_selections,
) = check_resource_usage(
local_result = check_resource_usage(
sub_field_type,
field,
validation_context,
Expand All @@ -248,6 +266,7 @@ def check_resource_usage(
camelcase_path=camelcase_path,
path_ignore_pattern=path_ignore_pattern,
get_limits_for_field=get_limits_for_field,
get_gas_for_field=get_gas_for_field,
# field_contributes_to_score will be casted to 1 for True
level_depth=level_depth + field_contributes_to_score
if sub_limits.depth is MISSING or not allow_reset
Expand All @@ -261,33 +280,71 @@ def check_resource_usage(
# called per query, selection
if (
merged_limits.complexity
and (new_depth - level_depth) * local_selections
and (local_result.max_level_depth - level_depth)
* local_result.selections
> merged_limits.complexity
):
on_error(ComplexityLimitReached("Query is too complex", node))
# increase level counter only if limits are not redefined
if sub_limits.depth is MISSING and new_depth > max_level_depth:
max_level_depth = new_depth
if (
sub_limits.depth is MISSING or "depth" in sub_limits.passthrough
) and local_result.max_level_depth > retval.max_level_depth:
retval.max_level_depth = local_result.max_level_depth
if (
sub_limits.complexity is MISSING
and new_depth_complexity > max_level_complexity
):
max_level_complexity = new_depth_complexity
or "complexity" in sub_limits.passthrough
) and local_result.max_level_complexity > retval.max_level_complexity:
retval.max_level_complexity = local_result.max_level_complexity

# ignore fields with selection_set itself for selection_count
# because we have depth for that
if sub_limits.selections is MISSING:
selections += local_selections
if (
sub_limits.selections is MISSING
or "selections" in sub_limits.passthrough
):
retval.selections += local_result.selections
if sub_limits.gas is MISSING or "gas" in sub_limits.passthrough:
retval.gas_used += local_result.gas_used
del schema_field
else:
field_contributes_to_score = True
if path_ignore_pattern.match(_path):
field_contributes_to_score = False
# field_contributes_to_score will be casted to 1 for True
selections += field_contributes_to_score
try:
schema_field = getattr(schema, fieldname)
except AttributeError:
_name = None
if hasattr(field, "name"):
_name = field.name
if hasattr(_name, "value"):
_name = _name.value
if (
hasattr(schema, "fields")
and not isinstance(schema, GraphQLInterfaceType)
and _name
):
schema_field = schema.fields[_name]
else:
schema_field = schema
retval.gas_used += get_gas_for_field(
schema_field,
parent=schema,
fieldname=fieldname,
)
if not path_ignore_pattern.match(_path):
# field_contributes_to_score
retval.selections += 1

if limits.selections and selections > limits.selections:
if limits.selections and retval.selections > limits.selections:
on_error(SelectionsLimitReached("Query selects too much", node))
return max_level_depth, max_level_complexity, selections
if limits.gas and retval.gas_used > limits.gas:
on_error(GasLimitReached("Query uses too much gas", node))
return retval


def gas_usage(gas_used: Union[Callable[[], int], int]):
def wrapper(schema_field):
setattr(schema_field, "_graphene_protector_gas", gas_used)
return schema_field

return wrapper


class LimitsValidationRule(ValidationRule):
Expand Down Expand Up @@ -351,6 +408,7 @@ def enter(self, node, key, parent, path, ancestors):
maintype = schema.get_type(operation_type)
assert maintype is not None
get_limits_for_field = limits_for_field
get_gas_for_field = gas_for_field
if hasattr(maintype, "graphene_type"):
maintype = maintype.graphene_type
if hasattr(schema, "_strawberry_schema"):
Expand All @@ -360,9 +418,7 @@ def get_limits_for_field(
):
name = follow_of_type(parent).name
definition = (
schema._strawberry_schema.schema_converter.type_map[
name
]
schema._strawberry_schema.schema_converter.type_map[name]
).definition
# e.g. union
if not hasattr(definition, "get_field"):
Expand All @@ -371,6 +427,18 @@ def get_limits_for_field(
nfield = definition.get_field(fieldname)
return limits_for_field(nfield, old_limits)

def get_gas_for_field(field, parent, fieldname, **kwargs):
name = follow_of_type(parent).name
definition = (
schema._strawberry_schema.schema_converter.type_map[name]
).definition
# e.g. union
if not hasattr(definition, "get_field"):
return gas_for_field(definition)

nfield = definition.get_field(fieldname)
return gas_for_field(nfield)

if getattr(self, "protector_on", True):
try:
check_resource_usage(
Expand All @@ -380,6 +448,7 @@ def get_limits_for_field(
self.default_limits,
self.report_error,
get_limits_for_field=get_limits_for_field,
get_gas_for_field=get_gas_for_field,
auto_snakecase=self.auto_snakecase,
camelcase_path=self.camelcase_path,
path_ignore_pattern=self.path_ignore_pattern,
Expand Down Expand Up @@ -461,9 +530,7 @@ class SchemaMixin:
protector_default_limits = None
protector_path_ignore_pattern = default_path_ignore_pattern

def __init_subclass__(
cls, protector_per_operation_validation=True, **kwargs
):
def __init_subclass__(cls, protector_per_operation_validation=True, **kwargs):
if hasattr(cls, "execute_sync"):
cls.execute_sync = decorate_limits(
cls.execute_sync, protector_per_operation_validation
Expand Down
3 changes: 3 additions & 0 deletions graphene_protector/django/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def get_protector_default_limits(self):
complexity=_get_default_limit_from_settings(
"GRAPHENE_PROTECTOR_COMPLEXITY_LIMIT"
),
gas=_get_default_limit_from_settings(
"GRAPHENE_PROTECTOR_GAS_LIMIT"
),
),
),
self.protector_default_limits,
Expand Down
Loading

0 comments on commit 2143eac

Please sign in to comment.