From ac03bce56648a8e00c4c82bc2d4d9ba921b7b44e Mon Sep 17 00:00:00 2001 From: Jessica Gadling Date: Wed, 5 Feb 2025 11:44:24 -0800 Subject: [PATCH] feat!: remove file handling (#124) * Updating dependencies. * Upgrade strawberry-graphql. * Allow Strawberry's GQL errors to pass through without changes. * Remove file handling. --- platformics/cli/main.py | 4 +- platformics/codegen/generator.py | 15 +- platformics/codegen/lib/linkml_wrappers.py | 59 +- .../templates/database/models/__init__.py.j2 | 4 +- .../database/models/class_name.py.j2 | 21 +- .../graphql_api/helpers/class_name.py.j2 | 8 +- .../templates/graphql_api/mutations.py.j2 | 11 - .../templates/graphql_api/queries.py.j2 | 8 - .../graphql_api/types/class_name.py.j2 | 66 +- .../test_infra/factories/class_name.py.j2 | 26 +- .../templates/validators/class_name.py.j2 | 4 +- platformics/database/models/__init__.py | 1 - platformics/database/models/base.py | 5 - platformics/database/models/file.py | 106 --- platformics/graphql_api/core/gql_loaders.py | 2 +- platformics/graphql_api/core/query_builder.py | 4 +- platformics/graphql_api/files.py | 613 ------------------ platformics/graphql_api/setup.py | 3 - platformics/security/authorization.py | 11 +- platformics/support/file_enums.py | 24 - platformics/support/format_handlers.py | 129 ---- platformics/test_infra/factories/base.py | 61 +- platformics/test_infra/main.py | 106 --- test_app/conftest.py | 3 +- test_app/schema/schema.yaml | 24 +- test_app/scripts/seed.py | 4 +- test_app/tests/test_cascade_deletion.py | 7 +- test_app/tests/test_file_concatenation.py | 100 --- test_app/tests/test_file_mutations.py | 309 --------- test_app/tests/test_file_queries.py | 104 --- test_app/tests/test_file_uploads.py | 147 ----- test_app/tests/test_where_clause.py | 3 +- 32 files changed, 118 insertions(+), 1874 deletions(-) delete mode 100644 platformics/database/models/file.py delete mode 100644 platformics/graphql_api/files.py delete mode 100644 platformics/support/file_enums.py delete mode 100644 platformics/support/format_handlers.py delete mode 100644 platformics/test_infra/main.py delete mode 100644 test_app/tests/test_file_concatenation.py delete mode 100644 test_app/tests/test_file_mutations.py delete mode 100644 test_app/tests/test_file_queries.py delete mode 100644 test_app/tests/test_file_uploads.py diff --git a/platformics/cli/main.py b/platformics/cli/main.py index c7b0b55..eac6021 100644 --- a/platformics/cli/main.py +++ b/platformics/cli/main.py @@ -43,20 +43,18 @@ def api() -> None: @api.command("generate") @click.option("--schemafile", type=str, required=True) @click.option("--output-prefix", type=str, required=True) -@click.option("--render-files/--skip-render-files", type=bool, default=True, show_default=True) @click.option("--template-override-paths", type=str, multiple=True) @click.pass_context def api_generate( ctx: click.Context, schemafile: str, output_prefix: str, - render_files: bool, template_override_paths: tuple[str], ) -> None: """ Launch code generation """ - generate(schemafile, output_prefix, render_files, template_override_paths) + generate(schemafile, output_prefix, template_override_paths) @cli.group() diff --git a/platformics/codegen/generator.py b/platformics/codegen/generator.py index 4e43b1b..bdbf3b4 100644 --- a/platformics/codegen/generator.py +++ b/platformics/codegen/generator.py @@ -62,7 +62,6 @@ def generate_entity_subclass_files( template_filename: str, environment: Environment, view: ViewWrapper, - render_files: bool, ) -> None: """ Code generation for 
SQLAlchemy models, GraphQL types, Cerbos policies, and Factoryboy factories @@ -76,13 +75,11 @@ def generate_entity_subclass_files( override_template = environment.get_template(f"{dest_filename}.j2") content = override_template.render( cls=entity, - render_files=render_files, view=view, ) else: content = template.render( cls=entity, - render_files=render_files, view=view, ) with open(os.path.join(output_prefix, dest_filename), mode="w", encoding="utf-8") as outfile: @@ -94,7 +91,6 @@ def generate_entity_import_files( output_prefix: str, environment: Environment, view: ViewWrapper, - render_files: bool, ) -> None: """ Code generation for database model imports, and GraphQL queries/mutations @@ -116,7 +112,6 @@ def generate_entity_import_files( import_template = environment.get_template(f"{filename}.j2") content = import_template.render( classes=classes, - render_files=render_files, view=view, ) with open(os.path.join(output_prefix, filename), mode="w", encoding="utf-8") as outfile: @@ -134,7 +129,7 @@ def regex_replace(txt, rgx, val, ignorecase=False, multiline=False): return compiled_rgx.sub(val, txt) -def generate(schemafile: str, output_prefix: str, render_files: bool, template_override_paths: tuple[str]) -> None: +def generate(schemafile: str, output_prefix: str, template_override_paths: tuple[str]) -> None: """ Launch code generation """ @@ -157,7 +152,7 @@ def generate(schemafile: str, output_prefix: str, render_files: bool, template_o # Generate enums and import files generate_enums(output_prefix, environment, wrapped_view) generate_limit_offset_type(output_prefix, environment) - generate_entity_import_files(output_prefix, environment, wrapped_view, render_files=render_files) + generate_entity_import_files(output_prefix, environment, wrapped_view) # Generate database models, GraphQL types, Cerbos policies, and Factoryboy factories generate_entity_subclass_files( @@ -165,40 +160,34 @@ def generate(schemafile: str, output_prefix: str, render_files: bool, template_o "database/models/class_name.py", environment, wrapped_view, - render_files=render_files, ) generate_entity_subclass_files( output_prefix, "graphql_api/types/class_name.py", environment, wrapped_view, - render_files=render_files, ) generate_entity_subclass_files( output_prefix, "validators/class_name.py", environment, wrapped_view, - render_files=render_files, ) generate_entity_subclass_files( output_prefix, "cerbos/policies/class_name.yaml", environment, wrapped_view, - render_files=render_files, ) generate_entity_subclass_files( output_prefix, "test_infra/factories/class_name.py", environment, wrapped_view, - render_files=render_files, ) generate_entity_subclass_files( output_prefix, "graphql_api/helpers/class_name.py", environment, wrapped_view, - render_files=render_files, ) diff --git a/platformics/codegen/lib/linkml_wrappers.py b/platformics/codegen/lib/linkml_wrappers.py index 58c1582..171b07f 100644 --- a/platformics/codegen/lib/linkml_wrappers.py +++ b/platformics/codegen/lib/linkml_wrappers.py @@ -5,6 +5,7 @@ functions to keep complicated LinkML-specific logic out of our Jinja2 templates. """ +import contextlib from functools import cached_property import strcase @@ -35,10 +36,19 @@ def identifier(self) -> str: def name(self) -> str: return self.wrapped_field.name.replace(" ", "_") + @cached_property + def description(self) -> str: + # Make sure to quote this so it's safe! 
+ return repr(self.wrapped_field.description) + @cached_property def camel_name(self) -> str: return strcase.to_lower_camel(self.name) + @cached_property + def type_designator(self) -> bool: + return bool(self.wrapped_field.designates_type) + @cached_property def multivalued(self) -> str: return self.wrapped_field.multivalued @@ -47,6 +57,10 @@ def multivalued(self) -> str: def required(self) -> bool: return self.wrapped_field.required or False + @cached_property + def designates_type(self) -> bool: + return self.wrapped_field.designates_type + # Validation attributes @cached_property def minimum_value(self) -> float | int | None: @@ -64,6 +78,11 @@ def maximum_value(self) -> float | int | None: def indexed(self) -> bool: if "indexed" in self.wrapped_field.annotations: return self.wrapped_field.annotations["indexed"].value + if self.identifier: + return True + with contextlib.suppress(NotImplementedError, AttributeError, ValueError): + if self.related_class.identifier: + return True return False @cached_property @@ -94,9 +113,7 @@ def hidden(self) -> bool: @cached_property def readonly(self) -> bool: is_readonly = self.wrapped_field.readonly - if is_readonly: - return True - return False + return bool(is_readonly) # Whether these fields should be available to change via an `Update` mutation # All fields are mutable by default, so long as they're not marked as readonly @@ -138,16 +155,12 @@ def inverse_field(self) -> str: @cached_property def is_enum(self) -> bool: field = self.view.get_element(self.wrapped_field.range) - if isinstance(field, EnumDefinition): - return True - return False + return bool(isinstance(field, EnumDefinition)) @cached_property def is_entity(self) -> bool: field = self.view.get_element(self.wrapped_field.range) - if isinstance(field, ClassDefinition): - return True - return False + return bool(isinstance(field, ClassDefinition)) @property def related_class(self) -> "EntityWrapper": @@ -163,6 +176,17 @@ def factory_type(self) -> str | None: return self.wrapped_field.annotations["factory_type"].value return None + @cached_property + def is_single_parent(self) -> bool: + # TODO, this parameter probably needs a better name. It's entirely SA specific right now. + # Basically we need it to tell SQLAlchemy that we have a 1:many relationship without a backref. + # Normally that's fine on its own, but if SQLALchemy will not allow us to enable cascading + # deletes unless we promise it (with this flag) that a given "child" object has only one parent, + # thereby making it safe to delete when the parent is deleted + if "single_parent" in self.wrapped_field.annotations: + return self.wrapped_field.annotations["single_parent"].value + return False + @cached_property def is_cascade_delete(self) -> bool: if "cascade_delete" in self.wrapped_field.annotations: @@ -234,15 +258,15 @@ def writable_fields(self) -> list[FieldWrapper]: return [FieldWrapper(self.view, item) for item in self.view.class_induced_slots(self.name) if not item.readonly] @cached_property - def identifier(self) -> str: - # Prioritize sending back identifiers from the entity mixin instead of inherited fields. + def identifier(self) -> FieldWrapper: + # Prioritize sending back identifiers from the current class and mixins instead of inherited fields. 
+ domains_owned_by_this_class = set(self.wrapped_class.mixins + [self.name]) for field in self.all_fields: - # FIXME, the entity.id / entity_id relationship is a little brittle right now :( - if field.identifier and "EntityMixin" in field.wrapped_field.domain_of: - return field.name + if field.identifier and domains_owned_by_this_class.intersection(set(field.wrapped_field.domain_of)): + return field for field in self.all_fields: - if field.identifier: - return field.name + if field.identifier and self.name in field.wrapped_field.domain_of: + return field raise Exception("No identifier found") @cached_property @@ -343,6 +367,9 @@ def enums(self) -> list[EnumWrapper]: enums = [] for enum_name in self.view.all_enums(): enum = self.view.get_element(enum_name) + # Don't codegen stuff that users asked us not to. + if enum.annotations.get("skip_codegen") and enum.annotations["skip_codegen"].value: + continue enums.append(EnumWrapper(self.view, enum)) return enums diff --git a/platformics/codegen/templates/database/models/__init__.py.j2 b/platformics/codegen/templates/database/models/__init__.py.j2 index f70c376..6c07e96 100644 --- a/platformics/codegen/templates/database/models/__init__.py.j2 +++ b/platformics/codegen/templates/database/models/__init__.py.j2 @@ -9,13 +9,11 @@ Make changes to the template codegen/templates/database/models/__init__.py.j2 in from sqlalchemy.orm import configure_mappers -from platformics.database.models import Base, meta, Entity, File, FileStatus # noqa: F401 +from platformics.database.models import Base, meta, Entity # noqa: F401 {%- for class in classes %} {%- if class.snake_name != "Entity" %} from database.models.{{ class.snake_name }} import {{ class.name }} # noqa: F401 {%- endif %} {%- endfor %} -from platformics.database.models.file import File, FileStatus # noqa: F401 - configure_mappers() diff --git a/platformics/codegen/templates/database/models/class_name.py.j2 b/platformics/codegen/templates/database/models/class_name.py.j2 index bb056dc..60d5d88 100644 --- a/platformics/codegen/templates/database/models/class_name.py.j2 +++ b/platformics/codegen/templates/database/models/class_name.py.j2 @@ -6,7 +6,7 @@ Make changes to the template codegen/templates/database/models/class_name.py.j2 """ {% set related_fields = cls.related_fields | unique(attribute='related_class.name') | list %} -{% set ignored_fields = ["File", "Entity", cls.name] %} +{% set ignored_fields = ["Entity", cls.name] %} import uuid import datetime @@ -26,19 +26,19 @@ from support.enums import {%- endfor %} if TYPE_CHECKING: - from database.models.file import File {%- for related_field in related_fields %} {%- if related_field.related_class.name not in ignored_fields %} from database.models.{{related_field.related_class.snake_name}} import {{related_field.related_class.name}} {%- endif %} {%- endfor %} + pass else: - File = "File" {%- for related_field in related_fields %} {%- if related_field.related_class.name not in ignored_fields %} {{related_field.related_class.name}} = "{{related_field.related_class.name}}" {%- endif %} {%- endfor %} + pass class {{cls.name}}(Entity): @@ -65,12 +65,13 @@ class {{cls.name}}(Entity): {%- elif attr.type == "date" %} {{attr.name}}: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True), nullable={{ "False" if attr.required else "True"}}{%- if attr.indexed%}, index=True{%- endif %}) {%- else %} + {%- if attr.is_single_parent -%} + {% set single_parent = 'single_parent=True,' %} + {%- else -%} + {% set single_parent = "" %} + {%- endif -%} {%- if 
attr.is_cascade_delete -%} - {%- if attr.type == "File" -%} - {% set cascade = 'cascade="all, delete-orphan", single_parent=True, post_update=True' %} - {%- else -%} - {% set cascade = 'cascade="all, delete-orphan"' %} - {%- endif -%} + {% set cascade = 'cascade="all, delete-orphan",' %} {%- else -%} {% set cascade = "" %} {%- endif -%} @@ -83,11 +84,12 @@ class {{cls.name}}(Entity): uselist=True, foreign_keys="{{attr.type}}.{{attr.inverse_field}}_id", {{cascade}} + {{single_parent}} ) {%- else %} {{attr.name}}_id: Mapped[uuid.UUID] = mapped_column( UUID, - ForeignKey("{{attr.related_class.snake_name}}.{{attr.related_class.identifier}}"), + ForeignKey("{{attr.related_class.snake_name}}.{{attr.related_class.identifier.name}}"), nullable={{"False" if attr.required else "True"}}, {%- if attr.identifier %} primary_key=True, @@ -103,6 +105,7 @@ class {{cls.name}}(Entity): back_populates="{{attr.inverse_field}}", {%- endif %} {{cascade}} + {{single_parent}} ) {%- endif %} {%- endif %} diff --git a/platformics/codegen/templates/graphql_api/helpers/class_name.py.j2 b/platformics/codegen/templates/graphql_api/helpers/class_name.py.j2 index 111c8fc..53856c4 100644 --- a/platformics/codegen/templates/graphql_api/helpers/class_name.py.j2 +++ b/platformics/codegen/templates/graphql_api/helpers/class_name.py.j2 @@ -5,7 +5,7 @@ Auto-gereanted by running 'make codegen'. Do not edit. Make changes to the template codegen/templates/graphql_api/groupby_helpers.py.j2 instead. """ {% set related_fields = cls.related_fields | unique(attribute='related_class.name') | list %} -{% set ignored_fields = ["File", "Entity", cls.name] %} +{% set ignored_fields = ["Entity", cls.name] %} from typing import Any, Optional import strawberry @@ -35,8 +35,8 @@ These are only used in aggregate queries. 
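An illustrative aside, not part of the patch: roughly what the database model template above now emits for a field annotated with cascade_delete and single_parent in the schema. The Sample and SequencingRead names are assumptions, and per the is_single_parent comment earlier in this patch, single_parent=True is the promise that a given child object has only one parent, which SQLAlchemy wants before it will enable cascading deletes on a relationship declared without a backref.

```python
# Hand-written sketch mirroring the {{cascade}} / {{single_parent}} output of the
# template above; "Sample" and "SequencingRead" are hypothetical models.
import uuid

from sqlalchemy import ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Sample(Base):
    __tablename__ = "sample"
    id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)
    # One-to-many collection declared without a backref: deleting a Sample
    # also deletes its SequencingRead children.
    sequencing_reads: Mapped[list["SequencingRead"]] = relationship(
        "SequencingRead",
        uselist=True,
        foreign_keys="SequencingRead.sample_id",
        cascade="all, delete-orphan",
        single_parent=True,  # promise SQLAlchemy each child has exactly one parent
    )


class SequencingRead(Base):
    __tablename__ = "sequencing_read"
    id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)
    sample_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("sample.id"), nullable=True)
```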
@strawberry.type class {{ cls.name }}GroupByOptions: {%- for attr in cls.visible_fields %} - {%- if attr.type != "File" and not attr.multivalued and not attr.is_virtual_relationship %} - {%- if attr.inverse %} + {%- if not attr.multivalued and not attr.is_virtual_relationship %} + {%- if attr.is_entity %} {{ attr.name }}: Optional[{{ attr.related_class.name }}GroupByOptions] = None {%- elif attr.type == cls.name %} {{ attr.name }}: Optional["{{ cls.name }}GroupByOptions"] = None @@ -78,7 +78,7 @@ def build_{{ cls.snake_name }}_groupby_output( key = keys.pop(0) match key: {%- for attr in cls.visible_fields %} - {%- if attr.type != "File" and not attr.multivalued and not attr.is_virtual_relationship %} + {%- if not attr.multivalued and not attr.is_virtual_relationship %} {%- if attr.inverse %} case "{{ attr.related_class.snake_name }}": if getattr(group_object, key): diff --git a/platformics/codegen/templates/graphql_api/mutations.py.j2 b/platformics/codegen/templates/graphql_api/mutations.py.j2 index 95d1101..f49dc5b 100644 --- a/platformics/codegen/templates/graphql_api/mutations.py.j2 +++ b/platformics/codegen/templates/graphql_api/mutations.py.j2 @@ -7,9 +7,6 @@ Make changes to the template codegen/templates/graphql_api/mutations.py.j2 inste import strawberry from typing import Sequence -{%- if render_files %} -from platformics.graphql_api.files import File, create_file, upload_file, upload_temporary_file, mark_upload_complete, concatenate_files, SignedURL, MultipartUploadResponse -{%- endif %} {%- for class in classes %} from graphql_api.types.{{ class.snake_name }} import {{ class.name }}, {%- if class.create_fields %}create_{{ class.snake_name }}, {%- endif %}{%- if class.mutable_fields %}update_{{ class.snake_name }}, {%- endif %}delete_{{ class.snake_name }} @@ -17,14 +14,6 @@ from graphql_api.types.{{ class.snake_name }} import {{ class.name }}, {%- if cl @strawberry.type class Mutation: - {%- if render_files %} - # File mutations - create_file: File = create_file - upload_file: MultipartUploadResponse = upload_file - upload_temporary_file: MultipartUploadResponse = upload_temporary_file - mark_upload_complete: File = mark_upload_complete - concatenate_files: SignedURL = concatenate_files - {%- endif %} {%- for class in classes %} # {{ class.name }} mutations diff --git a/platformics/codegen/templates/graphql_api/queries.py.j2 b/platformics/codegen/templates/graphql_api/queries.py.j2 index e8aa849..32155af 100644 --- a/platformics/codegen/templates/graphql_api/queries.py.j2 +++ b/platformics/codegen/templates/graphql_api/queries.py.j2 @@ -8,9 +8,6 @@ Make changes to the template codegen/templates/graphql_api/queries.py.j2 instead import strawberry from platformics.graphql_api import relay from typing import Sequence, List -{%- if render_files %} -from platformics.graphql_api.files import File, resolve_files -{%- endif %} {%- for class in classes %} from graphql_api.types.{{ class.snake_name }} import {{ class.name }}, resolve_{{ class.plural_snake_name }}, {{ class.name }}Aggregate, resolve_{{ class.plural_snake_name }}_aggregate {%- endfor %} @@ -22,11 +19,6 @@ class Query: node: relay.Node = relay.node() nodes: List[relay.Node] = relay.node() - {%- if render_files %} - # Query files - files: Sequence[File] = resolve_files - {%- endif %} - # Query entities {%- for class in classes %} {{ class.plural_snake_name }}: Sequence[{{ class.name }}] = resolve_{{ class.plural_snake_name }} diff --git a/platformics/codegen/templates/graphql_api/types/class_name.py.j2 
b/platformics/codegen/templates/graphql_api/types/class_name.py.j2 index a300918..46ee456 100644 --- a/platformics/codegen/templates/graphql_api/types/class_name.py.j2 +++ b/platformics/codegen/templates/graphql_api/types/class_name.py.j2 @@ -7,8 +7,23 @@ Make changes to the template codegen/templates/graphql_api/types/class_name.py.j # ruff: noqa: E501 Line too long +{%- set type_map = { + "uuid": "strawberry.ID", + "string": "str", + "Array2dFloat": "List[List[float]]", + "integer": "int", + "float": "float", + "boolean": "bool", + "date": "datetime.datetime", +} %} +{% macro getType(type, required) -%} + {%- if required %} {{ type }} + {%- else %} Optional[{{ type }}] + {%- endif %} +{%- endmacro %} + {% set related_fields = cls.related_fields | unique(attribute='related_class.name') | list %} -{% set ignored_fields = ["File", "Entity", cls.name] %} +{% set ignored_fields = ["Entity", cls.name] %} import typing from typing import TYPE_CHECKING, Annotated, Any, Optional, Sequence, Callable, List @@ -24,9 +39,6 @@ from validators.{{cls.snake_name}} import {{cls.name}}CreateInputValidator {%- if cls.mutable_fields %} from validators.{{cls.snake_name}} import {{cls.name}}UpdateInputValidator {%- endif %} -{%- if render_files %} -from platformics.graphql_api.files import File, FileWhereClause -{%- endif %} from graphql_api.helpers.{{ cls.snake_name }} import {{ cls.name }}GroupByOptions, build_{{ cls.snake_name }}_groupby_output from platformics.graphql_api.types.entities import EntityInterface {%- for related_field in related_fields %} @@ -59,7 +71,7 @@ from support.enums import {%- endfor %} -E = typing.TypeVar("E", base_db.File, db.{{ cls.name }}) +E = typing.TypeVar("E") T = typing.TypeVar("T") if TYPE_CHECKING: @@ -88,7 +100,7 @@ These are batching functions for loading related objects to avoid N+1 queries. 
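Another illustrative aside: approximately what this template block renders for a hypothetical to-one relation, following the same dataloader pattern as the removed File loader further down this hunk. SequencingRead, Sample, SampleWhereClause, and the module paths are assumed names for the sake of the example, not code produced by this patch.

```python
# Sketch of a generated batching loader for a hypothetical to-one relation.
# The request-scoped dataloader batches pending lookups into a single query,
# so resolving a page of SequencingReads does not issue one query per row.
from typing import Annotated, Optional

import strawberry
from sqlalchemy import inspect
from strawberry.types import Info

import database.models as db  # assumed import path for the generated models


@strawberry.field
async def load_sample_rows(
    root: "SequencingRead",
    info: Info,
    where: Optional[Annotated["SampleWhereClause", strawberry.lazy("graphql_api.types.sample")]] = None,
) -> Optional[Annotated["Sample", strawberry.lazy("graphql_api.types.sample")]]:
    dataloader = info.context["sqlalchemy_loader"]
    relationship = inspect(db.SequencingRead).relationships["sample"]
    return await dataloader.loader_for(relationship, where).load(root.sample_id)  # type: ignore
```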
""" {%- for related_field in related_fields %} - {%- if related_field.inverse and related_field.related_class.name not in ignored_fields %} + {%- if related_field.is_entity and related_field.related_class.name not in ignored_fields %} {%- if related_field.multivalued %} @relay.connection( relay.ListConnection[Annotated["{{ related_field.type }}", strawberry.lazy("graphql_api.types.{{ related_field.related_class.snake_name }}")]] # type:ignore @@ -144,31 +156,6 @@ async def load_{{ related_field.related_class.snake_name }}_aggregate_rows( {%- endfor %} -{%- if cls.owned_fields | map(attribute="type") | select("equalto", "File") | list | length > 0 %} -""" ------------------------------------------------------------------------------- -Dataloader for File object ------------------------------------------------------------------------------- -""" - -def load_files_from(attr_name: str) -> Callable: - @strawberry.field - async def load_files( - root: "{{ cls.name }}", - info: Info, - where: Annotated["FileWhereClause", strawberry.lazy("platformics.graphql_api.files")] | None = None, - ) -> Optional[Annotated["File", strawberry.lazy("platformics.graphql_api.files")]]: - """ - Given a list of {{ cls.name }} IDs for a certain file type, return related Files - """ - dataloader = info.context["sqlalchemy_loader"] - mapper = inspect(db.{{ cls.name }}) - relationship = mapper.relationships[attr_name] - return await dataloader.loader_for(relationship, where).load(getattr(root, f"{attr_name}_id")) # type:ignore - - return load_files -{%- endif %} - """ ------------------------------------------------------------------------------ Define Strawberry GQL types @@ -204,7 +191,7 @@ class {{ cls.name }}WhereClause(TypedDict): {{ attr.name }}: Optional[BoolComparators] | None {%- elif attr.type == "date" %} {{ attr.name }}: Optional[DatetimeComparators] | None - {%- elif attr.inverse %} + {%- elif attr.is_entity %} {{ attr.name }}: Optional[Annotated["{{ attr.type }}WhereClause", strawberry.lazy("graphql_api.types.{{ attr.related_class.snake_name }}")]] | None {%- elif attr.type == cls.name %} {{ attr.name }}_id: Optional[UUIDComparators] | None @@ -217,8 +204,8 @@ Supported ORDER BY clause attributes @strawberry.input class {{ cls.name }}OrderByClause(TypedDict): {%- for attr in cls.visible_fields %} - {%- if attr.type != "File" and not attr.multivalued %} - {%- if attr.inverse %} + {%- if not attr.multivalued %} + {%- if attr.is_entity %} {{ attr.name }}: Optional[Annotated["{{ attr.type }}OrderByClause", strawberry.lazy("graphql_api.types.{{ attr.related_class.snake_name }}")]] | None {%- else %} {{ attr.name }}: Optional[orderBy] | None @@ -257,16 +244,17 @@ class {{ cls.name }}(EntityInterface): {{ attr.name }}: {{ getType("bool", attr.required) }} {%- elif attr.type == "date" %} {{ attr.name }}: {{ getType("datetime.datetime", attr.required) }} - {%- elif attr.type == "File" %} - {{ attr.name }}_id: Optional[strawberry.ID] - {{ attr.name }}: Optional[Annotated["File", strawberry.lazy("platformics.graphql_api.files")]] = load_files_from("{{ attr.name }}") # type: ignore {%- elif attr.type == cls.name %} {{ attr.name }}_id: Optional[strawberry.ID] - {%- elif attr.inverse %} + {%- elif attr.is_entity %} {{ attr.name }}: {{ "Sequence" if attr.multivalued else "Optional" }}[Annotated["{{ attr.type }}", strawberry.lazy("graphql_api.types.{{ attr.related_class.snake_name }}")]] = load_{{ attr.related_class.snake_name }}_rows # type:ignore {%- if attr.multivalued %} {{ attr.name }}_aggregate : 
Optional[Annotated["{{ attr.related_class.name }}Aggregate", strawberry.lazy("graphql_api.types.{{ attr.related_class.snake_name }}")]] = load_{{ attr.related_class.snake_name }}_aggregate_rows # type:ignore + {%- else %} + {{ attr.name }}_id : {{ getType(type_map[attr.related_class.identifier.type], attr.required) }} {%- endif %} + {%- else %} + {%- endif %} {%- endfor %} @@ -384,8 +372,6 @@ Mutation types {{ attr.name }}: {{ getTypeMutation(action, "bool", attr.required) }} {%- elif attr.type == "date" %} {{ attr.name }}: {{ getTypeMutation(action, "datetime.datetime", attr.required) }} - {%- elif attr.type == "File" %} - {{ attr.name }}_id: {{ getTypeMutation(action, "strawberry.ID", attr.required) }} {%- elif attr.is_entity and not attr.is_virtual_relationship %} {# Don't include multivalued fields, only fields where we can update an ID #} {{ attr.name }}_id: {{ getTypeMutation(action, "strawberry.ID", attr.required) }} {%- endif %} diff --git a/platformics/codegen/templates/test_infra/factories/class_name.py.j2 b/platformics/codegen/templates/test_infra/factories/class_name.py.j2 index 63a3b60..48eac4d 100644 --- a/platformics/codegen/templates/test_infra/factories/class_name.py.j2 +++ b/platformics/codegen/templates/test_infra/factories/class_name.py.j2 @@ -10,9 +10,9 @@ Make changes to the template codegen/templates/test_infra/factories/class_name.p import random import factory from database.models import {{ cls.name }} -from platformics.test_infra.factories.base import CommonFactory, FileFactory +from platformics.test_infra.factories.base import CommonFactory {%- for field in cls.related_fields %} - {%- if field.inverse and field.related_class.name != "Entity" and not field.multivalued%} + {%- if field.related_class.name != "Entity" and not field.multivalued%} from test_infra.factories.{{ field.related_class.snake_name }} import {{ field.related_class.name }}Factory {%- endif %} {%- endfor %} @@ -27,7 +27,7 @@ Faker.add_provider(EnumProvider) class {{ cls.name }}Factory(CommonFactory): -{%- if cls.name not in ["Entity", "File"] %} +{%- if cls.name not in ["Entity"] %} class Meta: sqlalchemy_session = None # workaround for a bug in factoryboy model = {{ cls.name }} @@ -36,31 +36,13 @@ class {{ cls.name }}Factory(CommonFactory): sqlalchemy_get_or_create = ("entity_id",) {% for field in cls.owned_fields %} {%- if field.type != "uuid" %} - {%- if field.inverse and field.related_class.name != "Entity" %} + {%- if field.is_entity and field.related_class.name != "Entity" %} {#- If the field is a one-to-one relationship, avoid circular imports by only defining the SubFactory on the child #} {%- if not field.multivalued and not field.is_virtual_relationship %} {{ field.name }} = factory.SubFactory( {{ field.related_class.name }}Factory, owner_user_id=factory.SelfAttribute("..owner_user_id"), collection_id=factory.SelfAttribute("..collection_id"), - ) - {%- endif %} - {%- elif field.type == "File" %} - {#- If the schema specifies what file_type to use, use that. 
#} - {#- Otherwise, default to fastq.}} #} - {%- if field.factory_type is not none %} - {{ field.name }} = factory.RelatedFactory( - FileFactory, - factory_related_name="entity", - entity_field_name="{{ field.name }}", - file_format="{{ field.factory_type }}", - ) - {%- else %} - {{ field.name }} = factory.RelatedFactory( - FileFactory, - factory_related_name="entity", - entity_field_name="{{ field.name }}", - file_format="fastq", ) {%- endif %} {%- elif field.is_enum %} diff --git a/platformics/codegen/templates/validators/class_name.py.j2 b/platformics/codegen/templates/validators/class_name.py.j2 index 44fcccf..a353fbe 100644 --- a/platformics/codegen/templates/validators/class_name.py.j2 +++ b/platformics/codegen/templates/validators/class_name.py.j2 @@ -8,7 +8,7 @@ Make changes to the template codegen/templates/validators/class_name.py.j2 inste # ruff: noqa: E501 Line too long {% set related_fields = cls.related_fields | unique(attribute='related_class.name') | list %} -{% set ignored_fields = ["File", "Entity", cls.name] %} +{% set ignored_fields = ["Entity", cls.name] %} {%- for field in cls.enum_fields %} {%- if loop.first %} @@ -80,8 +80,6 @@ from typing_extensions import Annotated {{ attr.name }}: Annotated[{{ getTypeValidation(action, "bool", attr.required) }}, Field()] {{ defaultNoneValue(action, attr.required) }} {%- elif attr.type == "date" %} {{ attr.name }}: Annotated[{{ getTypeValidation(action, "datetime.datetime", attr.required) }}, Field()] {{ defaultNoneValue(action, attr.required) }} - {%- elif attr.type == "File" %} - {{ attr.name }}_id: Annotated[{{ getTypeValidation(action, "uuid.UUID", attr.required) }}, Field()] {{ defaultNoneValue(action, attr.required) }} {%- elif attr.is_entity and not attr.is_virtual_relationship %} {# Don't include multivalued fields, only fields where we can update an ID #} {{ attr.name }}_id: Annotated[{{ getTypeValidation(action, "uuid.UUID", attr.required) }}, Field()] {{ defaultNoneValue(action, attr.required) }} {%- endif %} diff --git a/platformics/database/models/__init__.py b/platformics/database/models/__init__.py index 611a7b9..3f7f608 100644 --- a/platformics/database/models/__init__.py +++ b/platformics/database/models/__init__.py @@ -8,4 +8,3 @@ # isort: skip_file from platformics.database.models.base import Base, meta, Entity # noqa: F401 -from platformics.database.models.file import File, FileStatus # noqa: F401 diff --git a/platformics/database/models/base.py b/platformics/database/models/base.py index 18d2725..23c4372 100644 --- a/platformics/database/models/base.py +++ b/platformics/database/models/base.py @@ -8,11 +8,6 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.sql import func -if TYPE_CHECKING: - from platformics.database.models.file import File -else: - File = "File" - meta = MetaData( naming_convention={ "ix": "ix_%(column_0_label)s", diff --git a/platformics/database/models/file.py b/platformics/database/models/file.py deleted file mode 100644 index 490f0bd..0000000 --- a/platformics/database/models/file.py +++ /dev/null @@ -1,106 +0,0 @@ -import datetime -import uuid -from typing import ClassVar - -import uuid6 -from mypy_boto3_s3.client import S3Client -from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String, event -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.engine import Connection -from sqlalchemy.orm import Mapped, Mapper, mapped_column, relationship -from sqlalchemy.sql import func - -from platformics.database.models.base import 
Base, Entity -from platformics.settings import APISettings -from platformics.support.file_enums import FileAccessProtocol, FileStatus, FileUploadClient - - -class File(Base): - __tablename__ = "file" - _settings: ClassVar[APISettings | None] = None - _s3_client: ClassVar[S3Client | None] = None - - @staticmethod - def get_settings() -> APISettings: - if not File._settings: - raise Exception("Settings not defined in this environment") - return File._settings - - @staticmethod - def set_settings(settings: APISettings) -> None: - File._settings = settings - - @staticmethod - def get_s3_client() -> S3Client: - if not File._s3_client: - raise Exception("S3 client not defined in this environment") - return File._s3_client - - @staticmethod - def set_s3_client(s3_client: S3Client) -> None: - File._s3_client = s3_client - - id: Column[uuid.UUID] = Column(UUID(as_uuid=True), primary_key=True, default=uuid6.uuid7) - - # TODO - the relationship between Entities and Files is currently being - # configured in both directions: entities have {fieldname}_file_id fields, - # *and* files have {entity_id, field_name} fields to map back to - # entities. We'll probably deprecate one side of this relationship in - # the future, but I'm not sure yet which one is going to prove to be - # more useful. - entity_id = mapped_column(ForeignKey("entity.id")) - entity_field_name: Mapped[str] = mapped_column(String, nullable=False) - entity: Mapped[Entity] = relationship(Entity, foreign_keys=entity_id) - - # TODO: Changes here need to be reflected in graphql_api/files.py - status: Mapped[FileStatus] = mapped_column(Enum(FileStatus, native_enum=False), nullable=False) - protocol: Mapped[FileAccessProtocol] = mapped_column(Enum(FileAccessProtocol, native_enum=False), nullable=False) - namespace: Mapped[str] = mapped_column(String, nullable=False) - path: Mapped[str] = mapped_column(String, nullable=False) - file_format: Mapped[str] = mapped_column(String, nullable=False) - compression_type: Mapped[str] = mapped_column(String, nullable=True) - size: Mapped[int] = mapped_column(Integer, nullable=True) - upload_client: Mapped[FileUploadClient] = mapped_column(Enum(FileUploadClient, native_enum=False), nullable=True) - upload_error: Mapped[str] = mapped_column(String, nullable=True) - created_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - server_default=func.now(), - ) - updated_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True), nullable=True) - - -@event.listens_for(File, "before_delete") -def before_delete(mapper: Mapper, connection: Connection, target: File) -> None: - """ - Before deleting a File object, check whether we need to delete it from S3, and - make sure to scrub the foreign keys in the Entity it's associated with. - """ - table_files = target.__table__ - table_entity = target.entity.__table__ - settings = File.get_settings() - s3_client = File.get_s3_client() - - # If this file is managed by platformics, see if it needs to be deleted from S3 - if target.path.startswith(f"{settings.OUTPUT_S3_PREFIX}/") and target.protocol == FileAccessProtocol.s3: - # Is this the last File object pointing to this path? 
- files_pointing_to_same_path = connection.execute( - table_files.select() - .where(table_files.c.id != target.id) - .where(table_files.c.protocol == target.protocol) - .where(table_files.c.namespace == target.namespace) - .where(table_files.c.path == target.path), - ) - - # If so, delete it from S3 - if len(list(files_pointing_to_same_path)) == 0: - response = s3_client.delete_object(Bucket=target.namespace, Key=target.path) - if response["ResponseMetadata"]["HTTPStatusCode"] != 204: - raise Exception("Failed to delete file from S3") - - # Finally, scrub the foreign keys in the related Entity - values = {f"{target.entity_field_name}_id": None} - # Modifying the target.entity directly does not save changes, we need to use `connection` - connection.execute( - table_entity.update().where(table_entity.c.entity_id == target.entity_id).values(**values), # type: ignore - ) diff --git a/platformics/graphql_api/core/gql_loaders.py b/platformics/graphql_api/core/gql_loaders.py index c6fee78..66ba1d3 100644 --- a/platformics/graphql_api/core/gql_loaders.py +++ b/platformics/graphql_api/core/gql_loaders.py @@ -11,7 +11,7 @@ from platformics.graphql_api.core.query_builder import get_aggregate_db_query, get_db_query, get_db_rows from platformics.security.authorization import AuthzAction, AuthzClient, Principal -E = typing.TypeVar("E", db.File, db.Entity) # type: ignore +E = typing.TypeVar("E") T = typing.TypeVar("T") diff --git a/platformics/graphql_api/core/query_builder.py b/platformics/graphql_api/core/query_builder.py index 2a98b52..ff87762 100644 --- a/platformics/graphql_api/core/query_builder.py +++ b/platformics/graphql_api/core/query_builder.py @@ -20,7 +20,7 @@ from platformics.graphql_api.core.query_input_types import aggregator_map, operator_map, orderBy from platformics.security.authorization import AuthzAction, AuthzClient, Principal -E = typing.TypeVar("E", db.File, db.Entity) +E = typing.TypeVar("E") T = typing.TypeVar("T") @@ -100,7 +100,7 @@ def convert_where_clauses_to_sql( # Unless deleted_at is explicitly set in the where clause OR we are performing a DELETE action, # we should only return rows where deleted_at is null. This is to ensure that we don't return soft-deleted rows. # Don't do this for files, since they don't have a deleted_at field. - if "deleted_at" not in local_where_clauses and action != AuthzAction.DELETE and sa_model.__name__ != "File": + if "deleted_at" not in local_where_clauses and action != AuthzAction.DELETE: local_where_clauses["deleted_at"] = {"_is_null": True} for group in group_by: # type: ignore col = strcase.to_snake(group.name) diff --git a/platformics/graphql_api/files.py b/platformics/graphql_api/files.py deleted file mode 100644 index 6592cc5..0000000 --- a/platformics/graphql_api/files.py +++ /dev/null @@ -1,613 +0,0 @@ -""" -GraphQL types, queries, and mutations for files. 
-""" - -import datetime -import json -import tempfile -import typing -import uuid -from dataclasses import dataclass - -import sqlalchemy as sa -import strawberry -import uuid6 -from fastapi import Depends -from mypy_boto3_s3.client import S3Client -from mypy_boto3_sts.client import STSClient -from sqlalchemy.exc import NoResultFound -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql import func -from strawberry.scalars import JSON -from strawberry.types import Info -from typing_extensions import TypedDict - -import platformics.database.models as db -from platformics.graphql_api.core.deps import ( - get_authz_client, - get_db_session, - get_s3_client, - get_settings, - get_sts_client, - require_auth_principal, - require_system_user, -) -from platformics.graphql_api.core.errors import PlatformicsError -from platformics.graphql_api.core.query_builder import get_db_rows -from platformics.graphql_api.core.query_input_types import ( - EnumComparators, - IntComparators, - StrComparators, - UUIDComparators, -) -from platformics.graphql_api.core.strawberry_extensions import DependencyExtension -from platformics.graphql_api.types.entities import Entity -from platformics.security.authorization import AuthzAction, AuthzClient, Principal -from platformics.settings import APISettings -from platformics.support import sqlalchemy_helpers -from platformics.support.file_enums import FileAccessProtocol, FileStatus -from platformics.support.format_handlers import get_validator - -FILE_TEMPORARY_PREFIX = "tmp" -FILE_CONCATENATION_MAX = 200 -FILE_CONCATENATION_MAX_SIZE = 50e3 # SARS-CoV-2 genome is ~30kbp -FILE_CONCATENATION_PREFIX = f"{FILE_TEMPORARY_PREFIX}/concatenated-files" -FILE_CONTENTS_MAX_SIZE = 1e6 # 1MB -UPLOADS_PREFIX = "uploads" - -# ------------------------------------------------------------------------------ -# Utility types/inputs -# ------------------------------------------------------------------------------ - - -@strawberry.type -@dataclass -class SignedURL: - """ - Signed URLs for downloading a file from S3. - """ - - url: str - protocol: str - method: str - expiration: int - fields: typing.Optional[JSON] = None # type: ignore - - -@strawberry.type -@dataclass -class MultipartUploadCredentials: - """ - STS token for uploading a file to S3. - """ - - protocol: str - namespace: str - path: str - access_key_id: str - secret_access_key: str - session_token: str - expiration: str - - -# Define graphQL input types so we can pass a "file" JSON to mutations. -# Keep them separate so we can control which fields are required. -@strawberry.input() -class FileUpload: - """ - GraphQL input type for uploading a file. - """ - - name: str - file_format: str - compression_type: typing.Optional[str] = None - - -@strawberry.input() -class FileCreate: - """ - GraphQL input type for creating a File object based on an existing S3 file (no upload). - """ - - name: str - file_format: str - compression_type: typing.Optional[str] = None - protocol: FileAccessProtocol - namespace: str - path: str - - -# ------------------------------------------------------------------------------ -# Data loader for fetching a File's related entity -# ------------------------------------------------------------------------------ - - -@strawberry.input -class EntityWhereClause(TypedDict): - """ - Supported where clause fields for the Entity type. 
- """ - - id: UUIDComparators | None - entity_id: typing.Optional[UUIDComparators] | None - producing_run_id: IntComparators | None - owner_user_id: IntComparators | None - collection_id: IntComparators | None - - -@strawberry.field(extensions=[DependencyExtension()]) -async def load_entities( - root: "File", - info: Info, - where: EntityWhereClause | None = None, -) -> typing.Optional[typing.Annotated["Entity", strawberry.lazy("platformics.graphql_api.types.entities")]]: - """ - Dataloader to fetch related entities, given file IDs. - """ - dataloader = info.context["sqlalchemy_loader"] - relationship = sqlalchemy_helpers.get_relationship(db.File, "entity") - return await dataloader.loader_for(relationship, where).load(root.entity_id) # type:ignore - - -# ------------------------------------------------------------------------------ -# Main types/inputs -# ------------------------------------------------------------------------------ - - -@strawberry.type -class File: - """ - GraphQL File type and fields. - """ - - id: strawberry.ID - entity_id: strawberry.ID - entity_field_name: str - entity: typing.Optional[typing.Annotated["Entity", strawberry.lazy("platformics.graphql_api.types.entities")]] = ( - load_entities - ) - status: FileStatus - protocol: FileAccessProtocol - namespace: str - path: str - file_format: str - compression_type: typing.Optional[int] = None - size: typing.Optional[int] = None - upload_error: typing.Optional[str] = None - created_at: datetime.datetime - updated_at: typing.Optional[datetime.datetime] = None - - @strawberry.field(extensions=[DependencyExtension()]) - def download_link( - self, - expiration: int = 3600, - s3_client: S3Client = Depends(get_s3_client), - ) -> typing.Optional[SignedURL]: - """ - Generate a signed URL for downloading a file from S3. - """ - if not self.path: # type: ignore - return None - params = {"Bucket": self.namespace, "Key": self.path} - url = s3_client.generate_presigned_url(ClientMethod="get_object", Params=params, ExpiresIn=expiration) - return SignedURL(url=url, protocol="https", method="get", expiration=expiration) - - @strawberry.field(extensions=[DependencyExtension()]) - def contents( - self, - s3_client: S3Client = Depends(get_s3_client), - ) -> str | None: - """ - Utility function to get file contents of small files. - """ - if not self.path or not self.size: - return None - if self.size > FILE_CONTENTS_MAX_SIZE: - raise Exception(f"Cannot download files larger than {FILE_CONTENTS_MAX_SIZE} bytes") - contents = s3_client.get_object(Bucket=self.namespace, Key=self.path)["Body"].read().decode("utf-8") - return contents - - -@strawberry.type -class MultipartUploadResponse: - """ - Return type for the uploadFile mutation. - """ - - credentials: MultipartUploadCredentials - file: File - - -@strawberry.input -class FileWhereClause(TypedDict): - """ - Supported where clause fields for the File type. 
- """ - - id: typing.Optional[UUIDComparators] - entity_id: typing.Optional[UUIDComparators] - entity_field_name: typing.Optional[StrComparators] - status: typing.Optional[EnumComparators[FileStatus]] - protocol: typing.Optional[StrComparators] - namespace: typing.Optional[StrComparators] - path: typing.Optional[StrComparators] - file_format: typing.Optional[StrComparators] - compression_type: typing.Optional[StrComparators] - size: typing.Optional[IntComparators] - - -@strawberry.field(extensions=[DependencyExtension()]) -async def resolve_files( - session: AsyncSession = Depends(get_db_session, use_cache=False), - authz_client: AuthzClient = Depends(get_authz_client), - principal: Principal = Depends(require_auth_principal), - where: typing.Optional[FileWhereClause] = None, -) -> typing.Sequence[File]: - """ - Handles files {} GraphQL queries. - """ - rows = await get_db_rows(db.File, session, authz_client, principal, where) - return rows # type: ignore - - -# ------------------------------------------------------------------------------ -# Utilities -# ------------------------------------------------------------------------------ - - -async def validate_file( - file: db.File, - session: AsyncSession = Depends(get_db_session, use_cache=False), - s3_client: S3Client = Depends(get_s3_client), -) -> None: - """ - Utility function to validate a file against its file format. - """ - validator = get_validator(format=file.file_format) - - # Validate data - try: - validator(s3_client, file.namespace, file.path).validate() - - file_size = s3_client.head_object(Bucket=file.namespace, Key=file.path)["ContentLength"] - except: # noqa - file.status = FileStatus.FAILED - else: - file.status = FileStatus.SUCCESS - file.size = file_size - - file.updated_at = func.now() - await session.commit() - - -def generate_multipart_upload_token( - new_file: db.File, - expiration: int = 3600, - sts_client: STSClient = Depends(get_sts_client), -) -> MultipartUploadCredentials: - """ - Utility function to generate an STS token for multipart upload. 
- """ - policy = { - "Version": "2012-10-17", - "Statement": [ - { - "Sid": "AllowSampleUploads", - "Effect": "Allow", - "Action": [ - "s3:GetObject", - "s3:PutObject", - "s3:CreateMultipartUpload", - "s3:AbortMultipartUpload", - "s3:ListMultipartUploadParts", - ], - "Resource": f"arn:aws:s3:::{new_file.namespace}/{new_file.path}", - }, - ], - } - - # Generate an STS token to allow users to - token_name = f"file-upload-token-{uuid6.uuid7()}" - creds = sts_client.get_federation_token(Name=token_name, Policy=json.dumps(policy), DurationSeconds=expiration) - - return MultipartUploadCredentials( - protocol="s3", - namespace=new_file.namespace, - path=new_file.path, - access_key_id=creds["Credentials"]["AccessKeyId"], - secret_access_key=creds["Credentials"]["SecretAccessKey"], - session_token=creds["Credentials"]["SessionToken"], - expiration=creds["Credentials"]["Expiration"].isoformat(), - ) - - -# ------------------------------------------------------------------------------ -# Mutations -# ------------------------------------------------------------------------------ - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def mark_upload_complete( - file_id: strawberry.ID, - principal: Principal = Depends(require_auth_principal), - authz_client: AuthzClient = Depends(get_authz_client), - session: AsyncSession = Depends(get_db_session, use_cache=False), - s3_client: S3Client = Depends(get_s3_client), -) -> db.File: - """ - Once a file is uploaded, the front-end should make a markUploadComplete mutation - to mark the file as ready for pipeline analysis. - """ - - # Get the type of entity that the file is related to - try: - file_row = (await session.execute(sa.select(db.File).where(db.File.id == file_id))).scalars().one() - except NoResultFound: - raise PlatformicsError("Unauthorized: cannot update file") from None - - # Fetch the entity if have access to it - entity_class, entity = await get_entity_by_id( - session, - authz_client, - principal, - AuthzAction.UPDATE, - file_row.entity_id, - ) - - # See if we actually have access to that file. - query = authz_client.get_resource_query( - principal, - AuthzAction.UPDATE, - db.File, - ) - query = query.filter(db.File.id == file_id) - file = (await session.execute(query)).scalars().one() - if not file: - raise Exception("NOT FOUND!") # TODO: How do we raise sane errors in our api? - - await validate_file(file, session, s3_client) - return file - - -# Need to create separate mutations because they return different types. -# Strawberry is unhappy with a mutation returning a union type. -@strawberry.mutation(extensions=[DependencyExtension()]) -async def create_file( - entity_id: strawberry.ID, - entity_field_name: str, - file: FileCreate, - session: AsyncSession = Depends(get_db_session, use_cache=False), - authz_client: AuthzClient = Depends(get_authz_client), - principal: Principal = Depends(require_auth_principal), - s3_client: S3Client = Depends(get_s3_client), - sts_client: STSClient = Depends(get_sts_client), - settings: APISettings = Depends(get_settings), -) -> db.File: - """ - Create a file object based on an existing S3 file (no upload). - """ - # Since user can specify an arbitrary path, make sure only a system user can do this. 
- require_system_user(principal) - new_file = await create_or_upload_file( - entity_id, - entity_field_name, - file, - -1, - session, - authz_client, - principal, - s3_client, - sts_client, - settings, - ) - assert isinstance(new_file, db.File) # reassure mypy that we're returning the right type - return new_file - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def upload_file( - entity_id: strawberry.ID, - entity_field_name: str, - file: FileUpload, - expiration: int = 3600, - session: AsyncSession = Depends(get_db_session, use_cache=False), - authz_client: AuthzClient = Depends(get_authz_client), - principal: Principal = Depends(require_auth_principal), - s3_client: S3Client = Depends(get_s3_client), - sts_client: STSClient = Depends(get_sts_client), - settings: APISettings = Depends(get_settings), -) -> MultipartUploadResponse: - """ - Create a file object and generate an STS token for multipart upload. - """ - response = await create_or_upload_file( - entity_id, - entity_field_name, - file, - expiration, - session, - authz_client, - principal, - s3_client, - sts_client, - settings, - ) - assert isinstance(response, MultipartUploadResponse) # reassure mypy that we're returning the right type - return response - - -async def get_entity_by_id( - session: AsyncSession, - authz_client: AuthzClient, - principal: Principal, - action: AuthzAction, - entity_id: strawberry.ID, -) -> tuple[typing.Type[db.Base], db.Base]: - # Fetch the entity if have access to it - try: - entity_row = (await session.execute(sa.select(db.Entity).where(db.Entity.id == entity_id))).scalars().one() - entity_class = sqlalchemy_helpers.get_orm_class_by_name(type(entity_row).__name__) - except NoResultFound: - raise PlatformicsError("Unauthorized: cannot create file") from None - - query = authz_client.get_resource_query(principal, action, entity_class) - query = query.filter(entity_class.entity_id == entity_id) - try: - entity = (await session.execute(query)).scalars().one() - except NoResultFound: - raise PlatformicsError("Unauthorized: cannot create file") from None - return entity_class, entity - - -async def create_or_upload_file( - entity_id: strawberry.ID, - entity_field_name: str, - file: FileCreate | FileUpload, - expiration: int = 3600, - session: AsyncSession = Depends(get_db_session, use_cache=False), - authz_client: AuthzClient = Depends(get_authz_client), - principal: Principal = Depends(require_auth_principal), - s3_client: S3Client = Depends(get_s3_client), - sts_client: STSClient = Depends(get_sts_client), - settings: APISettings = Depends(get_settings), -) -> db.File | MultipartUploadResponse: - """ - Utility function for creating a File object, whether for upload or for linking existing files. - """ - # Basic validation - if "/" in file.name: - raise Exception("File name should not contain /") - - # Fetch the entity if have access to it - entity_class, entity = await get_entity_by_id(session, authz_client, principal, AuthzAction.UPDATE, entity_id) - - # Does that entity type have a column for storing a file ID? 
- entity_property_name = f"{entity_field_name}_id" - if not hasattr(entity, entity_property_name): - raise Exception(f"This entity does not have a corresponding file of type {entity_field_name}") - - # Unlink the File(s) currently connected to this entity (only commit to DB once add the new File below) - if getattr(entity, entity_property_name): - query = authz_client.get_resource_query( - principal, - AuthzAction.UPDATE, - db.File, - ) - query = query.filter(db.File.entity_id == entity_id) - query = query.filter(db.File.entity_field_name == entity_field_name) - current_files = (await session.execute(query)).scalars().all() - for current_file in current_files: - current_file.entity_id = None - - # Set file parameters based on user inputs - file_id = uuid6.uuid7() - if isinstance(file, FileUpload): - file_protocol = settings.DEFAULT_UPLOAD_PROTOCOL - file_namespace = settings.DEFAULT_UPLOAD_BUCKET - file_path = f"{settings.OUTPUT_S3_PREFIX}/{UPLOADS_PREFIX}/{file_id}/{file.name}" - else: - file_protocol = file.protocol # type: ignore - file_namespace = file.namespace - file_path = file.path - - # Create a new file record - new_file = db.File( - id=file_id, - entity_id=entity_id, - entity_field_name=entity_field_name, - protocol=file_protocol, - namespace=file_namespace, - path=file_path, - file_format=file.file_format, - compression_type=file.compression_type, - status=FileStatus.PENDING, - ) - # Save file to db first - session.add(new_file) - await session.commit() - # Then update entity with file ID (if do both in one transaction, it will fail because of foreign key constraint) - setattr(entity, entity_property_name, new_file.id) - await session.commit() - - # If file already exists, validate it - if isinstance(file, FileCreate): - await validate_file(new_file, session, s3_client) - return new_file - - # If new file, create an STS token for multipart upload - else: - return MultipartUploadResponse( - file=new_file, # type: ignore - credentials=generate_multipart_upload_token(new_file, expiration, sts_client), - ) - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def upload_temporary_file( - expiration: int = 3600, - principal: Principal = Depends(require_auth_principal), - sts_client: STSClient = Depends(get_sts_client), - settings: APISettings = Depends(get_settings), -) -> MultipartUploadResponse: - """ - Generate upload tokens to upload files to S3 for temporary use (e.g. when export CGs to Nextclade, user can upload a - tree file that is sent to Nextclade and then not used in the app later). Only system users can do this. - """ - require_system_user(principal) - # Create a File object in memory because that's what `generate_multipart_upload_token` expects - # but this doesn't create a row in the File table. 
- new_file = db.File(namespace=settings.DEFAULT_UPLOAD_BUCKET, path=f"{FILE_TEMPORARY_PREFIX}/{uuid6.uuid7()}") - return MultipartUploadResponse( - file=new_file, # type: ignore - credentials=generate_multipart_upload_token(new_file, expiration, sts_client), - ) - - -@strawberry.mutation(extensions=[DependencyExtension()]) -async def concatenate_files( - ids: typing.Sequence[uuid.UUID], - session: AsyncSession = Depends(get_db_session, use_cache=False), - authz_client: AuthzClient = Depends(get_authz_client), - principal: Principal = Depends(require_auth_principal), - s3_client: S3Client = Depends(get_s3_client), - settings: APISettings = Depends(get_settings), -) -> SignedURL: - """ - Concatenate file contents synchronously (as opposed to the asynchronous bulk-download concatenate pipeline). - Only use for small files e.g. for exporting small CG FASTAs to Nextclade, where input file IDs is an array of - ConsensusGenome.sequence.id. We only support doing so on SARS-CoV-2 FASTAs (~30kbp genome) - so it's ok to do synchronously. - """ - if len(ids) > FILE_CONCATENATION_MAX: - raise Exception(f"Cannot concatenate more than {FILE_CONCATENATION_MAX} files") - - # Get files in question if have access to them - where = {"id": {"_in": ids}, "status": {"_eq": FileStatus.SUCCESS}} - files = await get_db_rows(db.File, session, authz_client, principal, where) - if len(files) < 2: - raise Exception("Need at least 2 valid files to concatenate") - for file in files: - if file.size > FILE_CONCATENATION_MAX_SIZE: - raise Exception("Cannot concatenate files larger than 1MB") - - # Concatenate files (tmp files are automatically deleted when closed) - with tempfile.NamedTemporaryFile() as file_concatenated: - with open(file_concatenated.name, "ab") as fp_concat: # noqa: ASYNC101 - for file in files: - # Download file locally and append it - with tempfile.NamedTemporaryFile() as file_temp: - s3_client.download_file(file.namespace, file.path, file_temp.name) - with open(file_temp.name, "rb") as fp_temp: # noqa: ASYNC101 - fp_concat.write(fp_temp.read()) - # Upload to S3 - path = f"{FILE_CONCATENATION_PREFIX}/{uuid6.uuid7()}" - s3_client.upload_file(file_concatenated.name, file.namespace, path) - - # Return signed URL - expiration = 36000 - url = s3_client.generate_presigned_url( - ClientMethod="get_object", - Params={"Bucket": settings.DEFAULT_UPLOAD_BUCKET, "Key": path}, - ExpiresIn=expiration, - ) - return SignedURL(url=url, protocol="https", method="get", expiration=expiration) diff --git a/platformics/graphql_api/setup.py b/platformics/graphql_api/setup.py index e469958..c47c13f 100644 --- a/platformics/graphql_api/setup.py +++ b/platformics/graphql_api/setup.py @@ -11,7 +11,6 @@ from strawberry.schema.name_converter import HasGraphQLName, NameConverter from platformics.database.connect import AsyncDB -from platformics.database.models.file import File from platformics.graphql_api.core.deps import ( get_auth_principal, get_authz_client, @@ -64,8 +63,6 @@ def get_app( """ Make sure tests can get their own instances of the app. 
""" - File.set_settings(settings) - File.set_s3_client(get_s3_client(settings)) title = settings.SERVICE_NAME graphql_app: GraphQLRouter = GraphQLRouter(schema, context_getter=get_context) diff --git a/platformics/security/authorization.py b/platformics/security/authorization.py index 72716cc..6ff4e69 100644 --- a/platformics/security/authorization.py +++ b/platformics/security/authorization.py @@ -116,14 +116,9 @@ def get_resource_query( attr_map = {} joins = [] - if model_cls == db.File: # type: ignore - for col in sqlalchemy_helpers.model_class_cols(db.Entity): - attr_map[f"request.resource.attr.{col.key}"] = getattr(db.Entity, col.key) - joins = [(db.Entity, db.File.entity_id == db.Entity.id)] # type: ignore - else: - # Send all non-relationship columns to cerbos to make decisions - for col in sqlalchemy_helpers.model_class_cols(model_cls): - attr_map[f"request.resource.attr.{col.key}"] = getattr(model_cls, col.key) + # Send all non-relationship columns to cerbos to make decisions + for col in sqlalchemy_helpers.model_class_cols(model_cls): + attr_map[f"request.resource.attr.{col.key}"] = getattr(model_cls, col.key) query = get_query( plan, model_cls, # type: ignore diff --git a/platformics/support/file_enums.py b/platformics/support/file_enums.py deleted file mode 100644 index f532424..0000000 --- a/platformics/support/file_enums.py +++ /dev/null @@ -1,24 +0,0 @@ -import enum - -import strawberry - - -@strawberry.enum -class FileStatus(enum.Enum): - SUCCESS = "SUCCESS" - FAILED = "FAILED" - PENDING = "PENDING" - - -@strawberry.enum -class FileAccessProtocol(enum.Enum): - s3 = "s3" - https = "https" - - -@strawberry.enum -class FileUploadClient(enum.Enum): - browser = "browser" - cli = "cli" - s3 = "s3" - basespace = "basespace" diff --git a/platformics/support/format_handlers.py b/platformics/support/format_handlers.py deleted file mode 100644 index 83ef53a..0000000 --- a/platformics/support/format_handlers.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -Logic to validate that a file of a certain format is valid -""" - -import gzip -import io -import json -from abc import abstractmethod -from typing import Protocol - -from Bio import SeqIO -from mypy_boto3_s3.client import S3Client - - -class FileFormatHandler(Protocol): - """ - Interface for a file format handler - """ - - s3client: S3Client - bucket: str - key: str - - def __init__(self, s3client: S3Client, bucket: str, key: str): - self.s3client = s3client - self.bucket = bucket - self.key = key - - def contents(self) -> str: - """ - Get the contents of the file - """ - body = self.s3client.get_object(Bucket=self.bucket, Key=self.key, Range="bytes=0-1000000")["Body"] - if self.key.endswith(".gz"): - with gzip.GzipFile(fileobj=body) as fp: - return fp.read().decode("utf-8") - return body.read().decode("utf-8") - - @abstractmethod - def validate(self) -> None: - raise NotImplementedError - - -class FastaHandler(FileFormatHandler): - """ - Validate FASTA files. Note that even truncated FASTA files are supported: - ">" is a valid FASTA file, and so is ">abc" (without a sequence). - """ - - def validate(self) -> None: - sequences = 0 - for _ in SeqIO.parse(io.StringIO(self.contents()), "fasta"): - sequences += 1 - assert sequences > 0 - - -class FastqHandler(FileFormatHandler): - """ - Validate FASTQ files. Can't use biopython directly because large file would be truncated. - This removes truncated FASTQ records by assuming 1 read = 4 lines. 
- """ - - def validate(self) -> None: - # Load file and only keep non-truncated FASTQ records (4 lines per record) - fastq = self.contents().split("\n") - fastq = fastq[: len(fastq) - (len(fastq) % 4)] - - # Validate it with SeqIO - reads = 0 - for _ in SeqIO.parse(io.StringIO("\n".join(fastq)), "fastq"): - reads += 1 - assert reads > 0 - - -class BedHandler(FileFormatHandler): - """ - Validate BED files using basic checks. - """ - - def validate(self) -> None: - # Ignore last line since it could be truncated - records = self.contents().split("\n")[:-1] - assert len(records) > 0 - - # BED files must have at least 3 columns - error out if the file incorrectly uses spaces instead of tabs - num_cols = -1 - for record in records: - assert len(record.split("\t")) >= 3 - # All rows should have the same number of columns - if num_cols == -1: - num_cols = len(record.split("\t")) - else: - assert num_cols == len(record.split("\t")) - - -class JsonHandler(FileFormatHandler): - """ - Validate JSON files - """ - - def validate(self) -> None: - json.loads(self.contents()) # throws an exception for invalid JSON - - -class ZipHandler(FileFormatHandler): - """ - Validate ZIP files - """ - - def validate(self) -> None: - assert self.key.endswith(".zip") # throws an exception if the file is not a zip file - - -def get_validator(format: str) -> type[FileFormatHandler]: - """ - Returns the validator for a given file format - """ - if format in ["fa", "fasta"]: - return FastaHandler - elif format == "fastq": - return FastqHandler - elif format == "bed": - return BedHandler - elif format == "json": - return JsonHandler - elif format == "zip": - return ZipHandler - else: - raise Exception(f"Unknown file format '{format}'") diff --git a/platformics/test_infra/factories/base.py b/platformics/test_infra/factories/base.py index 58291ef..3912a80 100644 --- a/platformics/test_infra/factories/base.py +++ b/platformics/test_infra/factories/base.py @@ -1,5 +1,5 @@ """ -File factory +Test factory bases """ import factory @@ -11,19 +11,13 @@ from faker_biology.physiology import Organ from faker_enum import EnumProvider -from platformics.database.models import Entity, File, FileStatus +from platformics.database.models import Entity Faker.add_provider(Bioseq) Faker.add_provider(Organ) Faker.add_provider(EnumProvider) -def generate_relative_file_path(obj) -> str: # type: ignore - fake = faker.Faker() - # Can't use absolute=True param because that requires newer version of faker than faker-biology supports - return fake.file_path(depth=3, extension=obj.file_format).lstrip("/") - - class SessionStorage: """ TODO: this is a lame singleton to prevent this library from requiring an active SA session at import-time. We @@ -54,54 +48,3 @@ class Meta: sqlalchemy_session_factory = SessionStorage.get_session sqlalchemy_session_persistence = "commit" sqlalchemy_session = None # workaround for a bug in factoryboy - - -class FileFactory(factory.alchemy.SQLAlchemyModelFactory): - """ - Factory for generating files - """ - - class Meta: - sqlalchemy_session_factory = SessionStorage.get_session - sqlalchemy_session_persistence = "commit" - sqlalchemy_session = None # workaround for a bug in factoryboy - model = File - # What fields do we try to match to existing db rows to determine whether we - # should create a new row or not? 
- sqlalchemy_get_or_create = ("namespace", "path") - - status = factory.Faker("enum", enum_cls=FileStatus) - protocol = "s3" - namespace = fuzzy.FuzzyChoice(["local-bucket", "remote-bucket"]) - path = factory.LazyAttribute(lambda o: generate_relative_file_path(o)) - file_format = fuzzy.FuzzyChoice(["fasta", "fastq", "bam"]) - compression_type = fuzzy.FuzzyChoice(["gz", "bz2", "xz"]) - size = fuzzy.FuzzyInteger(1024, 1024 * 1024 * 1024) # Between 1k and 1G - upload_client = fuzzy.FuzzyChoice(["browser", "cli", "s3", "basespace"]) - - @classmethod - def update_file_ids(cls) -> None: - """ - Function used by tests after creating entities to link files to entities - e.g. for SequencingRead, sets SequencingRead.r1_file_id = File.id - """ - session = SessionStorage.get_session() - if not session: - raise Exception("No session found") - # For each file, find the entity associated with it - # and update the file_id for that entity. - files = session.query(File).all() - for file in files: - if file.entity_id: - entity_field_name = file.entity_field_name - entity = session.query(Entity).filter(Entity.id == file.entity_id).first() - if entity: - entity_name = entity.type - session.execute( - sa.text( - f"""UPDATE {entity_name} SET {entity_field_name}_id = file.id - FROM file WHERE {entity_name}.entity_id = file.entity_id and file.entity_field_name = :field_name""", - ), - {"field_name": entity_field_name}, - ) - session.commit() diff --git a/platformics/test_infra/main.py b/platformics/test_infra/main.py deleted file mode 100644 index a3ce5ec..0000000 --- a/platformics/test_infra/main.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -File factory -""" - -import factory -import faker -import sqlalchemy as sa -import uuid6 -from factory import Faker, fuzzy -from faker_biology.bioseq import Bioseq -from faker_biology.physiology import Organ -from faker_enum import EnumProvider - -from platformics.database.models import Entity, File, FileStatus - -Faker.add_provider(Bioseq) -Faker.add_provider(Organ) -Faker.add_provider(EnumProvider) - - -def generate_relative_file_path(obj) -> str: # type: ignore - fake = faker.Faker() - # Can't use absolute=True param because that requires newer version of faker than faker-biology supports - return fake.file_path(depth=3, extension=obj.file_format).lstrip("/") - - -class SessionStorage: - """ - TODO: this is a lame singleton to prevent this library from requiring an active SA session at import-time. 
We - should try to refactor it out when we know more about factoryboy - """ - - session = None - - @classmethod - def set_session(cls, session: sa.orm.Session) -> None: - cls.session = session - - @classmethod - def get_session(cls) -> sa.orm.Session | None: - return cls.session - - -class CommonFactory(factory.alchemy.SQLAlchemyModelFactory): - """ - Base class for all factories - """ - - owner_user_id = fuzzy.FuzzyInteger(1, 1000) - collection_id = fuzzy.FuzzyInteger(1, 1000) - entity_id = uuid6.uuid7() # needed so we can set `sqlalchemy_get_or_create` = entity_id in other factories - - class Meta: - sqlalchemy_session_factory = SessionStorage.get_session - sqlalchemy_session_persistence = "commit" - sqlalchemy_session = None # workaround for a bug in factoryboy - - -class FileFactory(factory.alchemy.SQLAlchemyModelFactory): - """ - Factory for generating files - """ - - class Meta: - sqlalchemy_session_factory = SessionStorage.get_session - sqlalchemy_session_persistence = "commit" - sqlalchemy_session = None # workaround for a bug in factoryboy - model = File - # What fields do we try to match to existing db rows to determine whether we - # should create a new row or not? - sqlalchemy_get_or_create = ("namespace", "path") - - status = factory.Faker("enum", enum_cls=FileStatus) - protocol = "s3" - namespace = fuzzy.FuzzyChoice(["local-bucket", "remote-bucket"]) - path = factory.LazyAttribute(lambda o: generate_relative_file_path(o)) - file_format = fuzzy.FuzzyChoice(["fasta", "fastq", "bam"]) - compression_type = fuzzy.FuzzyChoice(["gz", "bz2", "xz"]) - size = fuzzy.FuzzyInteger(1024, 1024 * 1024 * 1024) # Between 1k and 1G - upload_client = fuzzy.FuzzyChoice(["browser", "cli", "s3", "basespace"]) - - @classmethod - def update_file_ids(cls) -> None: - """ - Function used by tests after creating entities to link files to entities - e.g. for SequencingRead, sets SequencingRead.r1_file_id = File.id - """ - session = SessionStorage.get_session() - if not session: - raise Exception("No session found") - # For each file, find the entity associated with it - # and update the file_id for that entity. 
- files = session.query(File).all() - for file in files: - if file.entity_id: - entity_field_name = file.entity_field_name - entity = session.query(Entity).filter(Entity.id == file.entity_id).first() - if entity: - entity_name = entity.type - session.execute( - sa.text( - f"""UPDATE {entity_name} SET {entity_field_name}_id = file.id - FROM file WHERE {entity_name}.entity_id = file.entity_id""", - ), - ) - session.commit() diff --git a/test_app/conftest.py b/test_app/conftest.py index 7111501..5ba465f 100644 --- a/test_app/conftest.py +++ b/test_app/conftest.py @@ -26,7 +26,7 @@ from platformics.graphql_api.setup import get_app from platformics.database.connect import AsyncDB, SyncDB, init_async_db, init_sync_db from platformics.database.models.base import Base -from platformics.test_infra.factories.base import FileFactory, SessionStorage +from platformics.test_infra.factories.base import SessionStorage from pytest_postgresql import factories from pytest_postgresql.executor_noop import NoopExecutor from pytest_postgresql.janitor import DatabaseJanitor @@ -44,7 +44,6 @@ "moto_client", "GQLTestClient", "SessionStorage", - "FileFactory", ] # needed by tests diff --git a/test_app/schema/schema.yaml b/test_app/schema/schema.yaml index 6618966..22f4d5d 100644 --- a/test_app/schema/schema.yaml +++ b/test_app/schema/schema.yaml @@ -62,7 +62,7 @@ classes: producing_run_id: range: uuid annotations: - mutable: false # This field can't be modified by an `Update` mutation + mutable: false # This field can't be modified by an `Update` mutation system_writable_only: True owner_user_id: range: integer @@ -89,19 +89,12 @@ classes: plural: Entities File: + is_a: Entity + mixins: + - EntityMixin + annotations: + plural: Files attributes: - id: - identifier: true - range: uuid - required: true - # This file's ID is stored in the entity column _id - entity_field_name: - range: string - required: true - # Which entity this file is connected to - entity: - range: Entity - required: true status: range: FileStatus required: true @@ -126,7 +119,7 @@ classes: minimum_value: 0 # Information about file upload (optional) upload_client: - range: FileUploadClient + range: string upload_error: range: string created_at: @@ -190,11 +183,13 @@ classes: readonly: true annotations: cascade_delete: true + single_parent: true r2_file: range: File readonly: true annotations: cascade_delete: true + single_parent: true technology: range: SequencingTechnology required: true @@ -226,6 +221,7 @@ classes: readonly: true annotations: cascade_delete: true + single_parent: true sequencing_reads: range: SequencingRead inverse: SequencingRead.primer_file diff --git a/test_app/scripts/seed.py b/test_app/scripts/seed.py index fe252dc..04468c1 100644 --- a/test_app/scripts/seed.py +++ b/test_app/scripts/seed.py @@ -5,7 +5,7 @@ import factory.random from platformics.database.connect import init_sync_db from platformics.settings import CLISettings -from platformics.test_infra.factories.base import FileFactory, SessionStorage +from platformics.test_infra.factories.base import SessionStorage from test_infra.factories.sample import SampleFactory from test_infra.factories.sequencing_read import SequencingReadFactory @@ -29,8 +29,6 @@ def use_factoryboy() -> None: SequencingReadFactory.create_batch(2, sample=sa2, owner_user_id=sa2.owner_user_id, collection_id=sa2.collection_id) SequencingReadFactory.create_batch(3, sample=sa3, owner_user_id=sa3.owner_user_id, collection_id=sa3.collection_id) - FileFactory.update_file_ids() - session.commit() diff 
--git a/test_app/tests/test_cascade_deletion.py b/test_app/tests/test_cascade_deletion.py index f8f6359..5757b21 100644 --- a/test_app/tests/test_cascade_deletion.py +++ b/test_app/tests/test_cascade_deletion.py @@ -4,7 +4,7 @@ import pytest from platformics.database.connect import SyncDB -from conftest import SessionStorage, GQLTestClient, FileFactory +from conftest import SessionStorage, GQLTestClient from test_infra.factories.sequencing_read import SequencingReadFactory @@ -25,7 +25,6 @@ async def test_cascade_delete( sequencing_reads = SequencingReadFactory.create_batch( 2, technology="Illumina", owner_user_id=user_id, collection_id=project_id ) - FileFactory.update_file_ids() # Delete the first Sample query = f""" @@ -63,7 +62,7 @@ async def test_cascade_delete( # Files from the first SequencingRead should be deleted query = f""" query MyQuery {{ - files(where: {{ entityId: {{ _eq: "{sequencing_reads[0].entity_id}" }} }}) {{ + files(where: {{ id: {{ _in: ["{sequencing_reads[0].r1_file.id}", "{sequencing_reads[0].r2_file.id}", "{sequencing_reads[0].primer_file.id}"] }} }}) {{ id }} }} @@ -74,7 +73,7 @@ async def test_cascade_delete( # Files from the second SequencingRead should NOT be deleted query = f""" query MyQuery {{ - files(where: {{ entityId: {{ _eq: "{sequencing_reads[1].entity_id}" }} }}) {{ + files(where: {{ id: {{ _in: ["{sequencing_reads[1].r1_file.id}", "{sequencing_reads[1].r2_file.id}", "{sequencing_reads[1].primer_file.id}"] }} }}) {{ id }} }} diff --git a/test_app/tests/test_file_concatenation.py b/test_app/tests/test_file_concatenation.py deleted file mode 100644 index 7a75c2b..0000000 --- a/test_app/tests/test_file_concatenation.py +++ /dev/null @@ -1,100 +0,0 @@ -""" -Test concatenating small files, both plain text and gzipped -""" - -import pytest -import requests -from mypy_boto3_s3.client import S3Client -from platformics.database.connect import SyncDB -from conftest import SessionStorage, GQLTestClient -from test_infra.factories.sequencing_read import SequencingReadFactory - - -@pytest.mark.parametrize( - "file_name_1,file_name_2", [("test1.fasta", "test2.fasta"), ("test1.fasta.gz", "test2.fasta.gz")] -) -@pytest.mark.asyncio -async def test_concatenation( - file_name_1: str, - file_name_2: str, - sync_db: SyncDB, - gql_client: GQLTestClient, - moto_client: S3Client, -) -> None: - """ - Upload 2 files and concatenate them - """ - user_id = 12345 - project_id = 111 - member_projects = [project_id] - fasta_file_1 = f"tests/fixtures/{file_name_1}" - fasta_file_2 = f"tests/fixtures/{file_name_2}" - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - sequencing_read = SequencingReadFactory.create(owner_user_id=user_id, collection_id=project_id) - entity_id = sequencing_read.entity_id - session.commit() - - # Create files - mutation = f""" - mutation MyQuery {{ - r1: uploadFile( - entityId: "{entity_id}", - entityFieldName: "r1_file", - file: {{ name: "{file_name_1}", fileFormat: "fasta" }} - ) {{ - file {{ id }} - credentials {{ namespace path }} - }} - - r2: uploadFile( - entityId: "{entity_id}", - entityFieldName: "r2_file", - file: {{ name: "{file_name_2}", fileFormat: "fasta" }} - ) {{ - file {{ id }} - credentials {{ namespace path }} - }} - }} - """ - output = await gql_client.query(mutation, member_projects=member_projects) - - # Upload files - credentials_1 = output["data"]["r1"]["credentials"] - credentials_2 = output["data"]["r2"]["credentials"] - moto_client.put_object(Bucket=credentials_1["namespace"], 
Key=credentials_1["path"], Body=open(fasta_file_1, "rb")) - moto_client.put_object(Bucket=credentials_2["namespace"], Key=credentials_2["path"], Body=open(fasta_file_2, "rb")) - - # Mark upload as complete - file_id_1 = output["data"]["r1"]["file"]["id"] - file_id_2 = output["data"]["r2"]["file"]["id"] - query = f""" - mutation MyMutation {{ - r1: markUploadComplete(fileId: "{file_id_1}") {{ status }} - r2: markUploadComplete(fileId: "{file_id_2}") {{ status }} - }} - """ - output = await gql_client.query(query, member_projects=member_projects) - assert output["data"]["r1"]["status"] == "SUCCESS" - assert output["data"]["r2"]["status"] == "SUCCESS" - - # Concatenate files - query = f""" - mutation Concatenate {{ - concatenateFiles(ids: ["{file_id_1}", "{file_id_2}"]) {{ - url - }} - }} - """ - output = await gql_client.query(query, member_projects=member_projects) - contents_observed = requests.get(output["data"]["concatenateFiles"]["url"]).content - - # Validate concatenated files - contents_expected = b"" - with open(fasta_file_1, "rb") as f: - contents_expected += f.read() - with open(fasta_file_2, "rb") as f: - contents_expected += f.read() - assert contents_expected == contents_observed diff --git a/test_app/tests/test_file_mutations.py b/test_app/tests/test_file_mutations.py deleted file mode 100644 index 6642075..0000000 --- a/test_app/tests/test_file_mutations.py +++ /dev/null @@ -1,309 +0,0 @@ -""" -Test file mutations for upload, linking an existing S3 file, and marking a file as completed -""" - -import os -import pytest -import typing -import sqlalchemy as sa -from mypy_boto3_s3.client import S3Client -from platformics.database.connect import SyncDB -from database.models import File, FileStatus -from conftest import SessionStorage, FileFactory, GQLTestClient -from test_infra.factories.sequencing_read import SequencingReadFactory -from database.models import SequencingRead - - -@pytest.mark.asyncio -async def test_file_validation( - sync_db: SyncDB, - gql_client: GQLTestClient, - moto_client: S3Client, -) -> None: - """ - Test that we can mark a file upload as complete - """ - user1_id = 12345 - project1_id = 123 - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - SequencingReadFactory.create(owner_user_id=user1_id, collection_id=project1_id) - FileFactory.update_file_ids() - session.commit() - files = session.execute(sa.select(File)).scalars().all() - file = list(filter(lambda file: file.entity_field_name == "r1_file", files))[0] - - valid_fastq_file = "tests/fixtures/test1.fastq" - file_size = os.stat(valid_fastq_file).st_size - moto_client.put_object(Bucket=file.namespace, Key=file.path, Body=open(valid_fastq_file, "rb")) - - # Mark upload complete - query = f""" - mutation MyMutation {{ - markUploadComplete(fileId: "{file.id}") {{ - id - namespace - size - status - }} - }} - """ - res = await gql_client.query(query, member_projects=[project1_id]) - fileinfo = res["data"]["markUploadComplete"] - assert fileinfo["status"] == "SUCCESS" - assert fileinfo["size"] == file_size - - # Make sure the file was updated in the database - with sync_db.session() as session: - files = session.execute(sa.select(File)).scalars().all() - file = list(filter(lambda file: file.entity_field_name == "r1_file", files))[0] - assert file.status == FileStatus.SUCCESS - assert file.size == file_size - - -@pytest.mark.asyncio -async def test_invalid_fastq( - sync_db: SyncDB, - gql_client: GQLTestClient, - moto_client: S3Client, -) -> None: - """ - Test that 
invalid fastq's don't work - """ - user1_id = 12345 - project1_id = 123 - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - SequencingReadFactory.create(owner_user_id=user1_id, collection_id=project1_id) - FileFactory.update_file_ids() - session.commit() - files = session.execute(sa.select(File)).scalars().all() - file = list(filter(lambda file: file.entity_field_name == "r1_file", files))[0] - - moto_client.put_object(Bucket=file.namespace, Key=file.path, Body="this is not a fastq file") - - # Mark upload complete - query = f""" - mutation MyMutation {{ - markUploadComplete(fileId: "{file.id}") {{ - id - namespace - size - status - }} - }} - """ - res = await gql_client.query(query, member_projects=[project1_id]) - fileinfo = res["data"]["markUploadComplete"] - assert fileinfo["status"] == "FAILED" - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "member_projects,project_id,entity_field", - [ - ([456], 123, "r1_file"), # Can't create file for entity you don't have access to - ([123], 123, "does_not_exist"), # Can't create file for entity that isn't connected to a valid file type - ([123], 123, "r1_file"), # Can create file for entity you have access to - ], -) -async def test_upload_file( - member_projects: list[int], - project_id: int, - entity_field: str, - sync_db: SyncDB, - gql_client: GQLTestClient, -) -> None: - """ - Test generating STS tokens for file uploads - """ - user_id = 12345 - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - SequencingReadFactory.create(owner_user_id=user_id, collection_id=project_id) - FileFactory.update_file_ids() - session.commit() - - sequencing_read = session.execute(sa.select(SequencingRead)).scalars().one() - entity_id = sequencing_read.entity_id - - # Try creating a file - mutation = f""" - mutation MyQuery {{ - uploadFile( - entityId: "{entity_id}", - entityFieldName: "{entity_field}", - file: {{ - name: "test.fastq", - fileFormat: "fastq" - }} - ) {{ - credentials {{ - namespace - path - accessKeyId - secretAccessKey - expiration - }} - }} - }} - """ - output = await gql_client.query(mutation, member_projects=member_projects) - - # If don't have access to this entity, or trying to link an entity with a made up file type, should get an error - if project_id not in member_projects or entity_field == "does_not_exist": - assert output["data"] is None - assert output["errors"] is not None - return - - # Moto produces hard-coded tokens - assert output["data"]["uploadFile"]["credentials"]["accessKeyId"].endswith("EXAMPLE") - assert output["data"]["uploadFile"]["credentials"]["secretAccessKey"].endswith("EXAMPLEKEY") - - -@pytest.mark.asyncio -async def test_create_file( - sync_db: SyncDB, - gql_client: GQLTestClient, - moto_client: S3Client, -) -> None: - """ - Test adding an existing file to the entities service - """ - with sync_db.session() as session: - # Create sequencing read and file - SessionStorage.set_session(session) - SequencingReadFactory.create(owner_user_id=12345, collection_id=123) - FileFactory.update_file_ids() - session.commit() - - sequencing_read = session.execute(sa.select(SequencingRead)).scalars().one() - entity_id = sequencing_read.entity_id - - # Upload a fastq file to a mock bucket so we can create a file object from it - file_namespace = "local-bucket" - file_path = "test1.fastq" - file_path_local = "tests/fixtures/test1.fastq" - file_size = os.stat(file_path_local).st_size - with open(file_path_local, "rb") as fp: - 
moto_client.put_object(Bucket=file_namespace, Key=file_path, Body=fp) - - # Try creating a file from existing file on S3 - mutation = f""" - mutation MyQuery {{ - createFile( - entityId: "{entity_id}", - entityFieldName: "r1_file", - file: {{ - name: "{file_path}", - fileFormat: "fastq", - protocol: s3, - namespace: "{file_namespace}", - path: "{file_path}" - }} - ) {{ - path - size - }} - }} - """ - output = await gql_client.query(mutation, member_projects=[123], service_identity="workflows") - assert output["data"]["createFile"]["size"] == file_size - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "file_path,multiple_files_for_one_path,should_delete", - [ - ("platformics/test1.fastq", False, True), - ("bla/test1.fastq", False, False), - ("platformics/test1.fastq", True, False), - ("bla/test1.fastq", True, False), - ], -) -async def test_delete_from_s3( - file_path: str, - should_delete: bool, - multiple_files_for_one_path: bool, - sync_db: SyncDB, - gql_client: GQLTestClient, - moto_client: S3Client, - monkeypatch: typing.Any, -) -> None: - """ - Test that we delete a file from S3 under the right circumstances - """ - user1_id = 12345 - project1_id = 123 - user2_id = 67890 - project2_id = 456 - bucket = "local-bucket" - - # Patch the S3 client to make sure tests are operating on the same mock bucket - monkeypatch.setattr(File, "get_s3_client", lambda: moto_client) - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - SequencingReadFactory.create(owner_user_id=user1_id, collection_id=project1_id) - FileFactory.update_file_ids() - session.commit() - files = session.execute(sa.select(File)).scalars().all() - file = list(filter(lambda file: file.entity_field_name == "r1_file", files))[0] - file.path = file_path - file.namespace = bucket # set the bucket to make sure the mock file is in the right place - session.commit() - - # Also test the case where multiple files point to the same path - if multiple_files_for_one_path: - sequencing_read = SequencingReadFactory.create(owner_user_id=user2_id, collection_id=project2_id) - FileFactory.update_file_ids() - session.commit() - session.refresh(sequencing_read) - sequencing_read.r1_file.path = file_path - sequencing_read.r1_file.namespace = bucket - session.commit() - - valid_fastq_file = "tests/fixtures/test1.fastq" - moto_client.put_object(Bucket=file.namespace, Key=file.path, Body=open(valid_fastq_file, "rb")) - - # Delete SequencingRead and cascade to File objects - query = f""" - mutation MyMutation {{ - deleteSequencingRead(where: {{ id: {{ _eq: "{file.entity_id}" }} }}) {{ - id - }} - }} - """ - - # File should exist on S3 before the deletion - assert "Contents" in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) - - # Issue deletion - result = await gql_client.query( - query, user_id=user1_id, member_projects=[project1_id], service_identity="workflows" - ) - assert result["data"]["deleteSequencingRead"][0]["id"] == str(file.entity_id) - - # Make sure file either does or does not exist - if should_delete: - assert "Contents" not in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) - else: - assert "Contents" in moto_client.list_objects(Bucket=file.namespace, Prefix=file.path) - - # Make sure File object doesn't exist either - query = f""" - query MyQuery {{ - files(where: {{ id: {{ _eq: "{file.id}" }} }}) {{ - id - }} - }} - """ - result = await gql_client.query(query, user_id=user1_id, member_projects=[project1_id]) - assert result["data"]["files"] == [] 
diff --git a/test_app/tests/test_file_queries.py b/test_app/tests/test_file_queries.py deleted file mode 100644 index 1488b51..0000000 --- a/test_app/tests/test_file_queries.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Test file queries -""" - -import pytest -from conftest import FileFactory, GQLTestClient, SessionStorage -from test_infra.factories.sequencing_read import SequencingReadFactory -from platformics.database.connect import SyncDB - - -@pytest.mark.asyncio -async def test_file_query( - sync_db: SyncDB, - gql_client: GQLTestClient, -) -> None: - """ - Test that we can only fetch files that we have access to - """ - user1_id = 12345 - user2_id = 67890 - user3_id = 87654 - project1_id = 123 - project2_id = 456 - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - SequencingReadFactory.create_batch(2, owner_user_id=user1_id, collection_id=project1_id) - SequencingReadFactory.create_batch(6, owner_user_id=user2_id, collection_id=project1_id) - SequencingReadFactory.create_batch(4, owner_user_id=user3_id, collection_id=project2_id) - FileFactory.update_file_ids() - - # Fetch all samples - query = """ - query MyQuery { - files { - entity { - collectionId - ownerUserId - id - type - } - path - entityFieldName - } - } - """ - output = await gql_client.query(query, member_projects=[project1_id]) - # Each SequencingRead results in 3 files: - # r1_file, r2_file - # primer_file -> GenomicRange file - # so we expect 8 * 3 = 24 files. - assert len(output["data"]["files"]) == 24 - for file in output["data"]["files"]: - assert file["path"] is not None - assert file["entity"]["collectionId"] == project1_id - assert file["entity"]["ownerUserId"] in (user1_id, user2_id) - - -@pytest.mark.asyncio -async def test_nested_files( - sync_db: SyncDB, - gql_client: GQLTestClient, -) -> None: - """ - Test that we can fetch related file info. 
- """ - user1_id = 12345 - user2_id = 67890 - user3_id = 87654 - project1_id = 123 - project2_id = 456 - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - SequencingReadFactory.create_batch(2, owner_user_id=user1_id, collection_id=project1_id) - SequencingReadFactory.create_batch(6, owner_user_id=user2_id, collection_id=project1_id) - SequencingReadFactory.create_batch(4, owner_user_id=user3_id, collection_id=project2_id) - FileFactory.update_file_ids() - - # Fetch all samples - query = """ - query MyQuery { - sequencingReads { - r1File { - entityId - fileFormat - path - size - } - nucleicAcid - id - ownerUserId - } - } - """ - output = await gql_client.query(query, member_projects=[project1_id]) - assert len(output["data"]["sequencingReads"]) == 8 - - for read in output["data"]["sequencingReads"]: - assert read["r1File"] is not None - assert read["r1File"]["entityId"] == read["id"] - assert read["r1File"]["path"] diff --git a/test_app/tests/test_file_uploads.py b/test_app/tests/test_file_uploads.py deleted file mode 100644 index fbba060..0000000 --- a/test_app/tests/test_file_uploads.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -Test end-to-end upload process -""" - -import os -import pytest -from mypy_boto3_s3.client import S3Client -from platformics.database.connect import SyncDB -from conftest import SessionStorage, GQLTestClient -from test_infra.factories.sequencing_read import SequencingReadFactory - - -@pytest.mark.asyncio -async def test_upload_process( - sync_db: SyncDB, - gql_client: GQLTestClient, - moto_client: S3Client, -) -> None: - """ - Generate AWS credentials, upload a file, and mark the upload as complete - """ - user_id = 12345 - project_id = 111 - member_projects = [project_id] - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - sequencing_read = SequencingReadFactory.create(owner_user_id=user_id, collection_id=project_id) - entity_id = sequencing_read.entity_id - session.commit() - - # Get AWS creds to upload an R1 fastq file - mutation = f""" - mutation MyQuery {{ - uploadFile( - entityId: "{entity_id}", - entityFieldName: "r1_file", - file: {{ - name: "some_file.fastq", - fileFormat: "fastq" - }} - ) {{ - file {{ - id - status - }} - credentials {{ - namespace - path - accessKeyId - secretAccessKey - expiration - }} - }} - }} - """ - - output = await gql_client.query(mutation, member_projects=member_projects, user_id=user_id) - file_id = output["data"]["uploadFile"]["file"]["id"] - credentials = output["data"]["uploadFile"]["credentials"] - - # Upload the file - fastq_file = "tests/fixtures/test1.fastq" - fastq_file_size = os.stat(fastq_file).st_size - moto_client.put_object(Bucket=credentials["namespace"], Key=credentials["path"], Body=open(fastq_file, "rb")) - - # Mark upload complete - query = f""" - mutation MyMutation {{ - markUploadComplete(fileId: "{file_id}") {{ - id - namespace - size - status - }} - }} - """ - output = await gql_client.query(query, member_projects=member_projects) - assert output["data"]["markUploadComplete"]["status"] == "SUCCESS" - assert output["data"]["markUploadComplete"]["size"] == fastq_file_size - - -@pytest.mark.asyncio -async def test_upload_process_multiple_files_per_entity( - sync_db: SyncDB, - gql_client: GQLTestClient, - moto_client: S3Client, -) -> None: - """ - Make sure that entities with multiple file links still behave correctly. 
Test the logic that unlinks a File from - an entity if another file is uploaded to the same entity field. - """ - user_id = 12345 - project_id = 111 - member_projects = [project_id] - fastq_file = "tests/fixtures/test1.fastq" - - # Create mock data - with sync_db.session() as session: - SessionStorage.set_session(session) - sequencing_read = SequencingReadFactory.create(owner_user_id=user_id, collection_id=project_id) - entity_id = sequencing_read.entity_id - session.commit() - - # Create files - mutation = f""" - mutation MyQuery {{ - r1: uploadFile( - entityId: "{entity_id}", - entityFieldName: "r1_file", - file: {{ name: "some_file1.fastq", fileFormat: "fastq" }} - ) {{ - file {{ id }} - credentials {{ namespace path }} - }} - - r2: uploadFile( - entityId: "{entity_id}", - entityFieldName: "r2_file", - file: {{ name: "some_file2.fastq", fileFormat: "fastq" }} - ) {{ - file {{ id }} - credentials {{ namespace path }} - }} - }} - """ - output = await gql_client.query(mutation, member_projects=member_projects) - - # Upload files - credentials_1 = output["data"]["r1"]["credentials"] - credentials_2 = output["data"]["r2"]["credentials"] - moto_client.put_object(Bucket=credentials_1["namespace"], Key=credentials_1["path"], Body=open(fastq_file, "rb")) - moto_client.put_object(Bucket=credentials_2["namespace"], Key=credentials_2["path"], Body=open(fastq_file, "rb")) - - # Mark upload as complete - file_id_1 = output["data"]["r1"]["file"]["id"] - file_id_2 = output["data"]["r2"]["file"]["id"] - query = f""" - mutation MyMutation {{ - r1: markUploadComplete(fileId: "{file_id_1}") {{ status }} - r2: markUploadComplete(fileId: "{file_id_2}") {{ status }} - }} - """ - output = await gql_client.query(query, member_projects=member_projects) - assert output["data"]["r1"]["status"] == "SUCCESS" - assert output["data"]["r2"]["status"] == "SUCCESS" diff --git a/test_app/tests/test_where_clause.py b/test_app/tests/test_where_clause.py index e4613e2..47229ff 100644 --- a/test_app/tests/test_where_clause.py +++ b/test_app/tests/test_where_clause.py @@ -4,7 +4,7 @@ import pytest from platformics.database.connect import SyncDB -from conftest import GQLTestClient, SessionStorage, FileFactory +from conftest import GQLTestClient, SessionStorage from test_infra.factories.sample import SampleFactory from test_infra.factories.sequencing_read import SequencingReadFactory from support.enums import SequencingTechnology @@ -254,7 +254,6 @@ async def test_soft_deleted_objects(sync_db: SyncDB, gql_client: GQLTestClient) By default, soft-deleted objects should not be returned. """ sequencing_reads = generate_sequencing_reads(sync_db) - FileFactory.update_file_ids() # Soft delete the first 3 sequencing reads by updating the deleted_at field deleted_ids = [str(sequencing_reads[0].id), str(sequencing_reads[1].id), str(sequencing_reads[2].id)] soft_delete_mutation = f"""