Skip to content

Commit

Permalink
feat!: remove file handling (#124)
Browse files Browse the repository at this point in the history
* Updating dependencies.

* Upgrade strawberry-graphql.

* Allow Strawberry's GQL errors to pass through without changes.

* Remove file handling.
  • Loading branch information
jgadling authored Feb 5, 2025
1 parent bccb306 commit ac03bce
Show file tree
Hide file tree
Showing 32 changed files with 118 additions and 1,874 deletions.
4 changes: 1 addition & 3 deletions platformics/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,18 @@ def api() -> None:
@api.command("generate")
@click.option("--schemafile", type=str, required=True)
@click.option("--output-prefix", type=str, required=True)
@click.option("--render-files/--skip-render-files", type=bool, default=True, show_default=True)
@click.option("--template-override-paths", type=str, multiple=True)
@click.pass_context
def api_generate(
ctx: click.Context,
schemafile: str,
output_prefix: str,
render_files: bool,
template_override_paths: tuple[str],
) -> None:
"""
Launch code generation
"""
generate(schemafile, output_prefix, render_files, template_override_paths)
generate(schemafile, output_prefix, template_override_paths)


@cli.group()
Expand Down
15 changes: 2 additions & 13 deletions platformics/codegen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def generate_entity_subclass_files(
template_filename: str,
environment: Environment,
view: ViewWrapper,
render_files: bool,
) -> None:
"""
Code generation for SQLAlchemy models, GraphQL types, Cerbos policies, and Factoryboy factories
Expand All @@ -76,13 +75,11 @@ def generate_entity_subclass_files(
override_template = environment.get_template(f"{dest_filename}.j2")
content = override_template.render(
cls=entity,
render_files=render_files,
view=view,
)
else:
content = template.render(
cls=entity,
render_files=render_files,
view=view,
)
with open(os.path.join(output_prefix, dest_filename), mode="w", encoding="utf-8") as outfile:
Expand All @@ -94,7 +91,6 @@ def generate_entity_import_files(
output_prefix: str,
environment: Environment,
view: ViewWrapper,
render_files: bool,
) -> None:
"""
Code generation for database model imports, and GraphQL queries/mutations
Expand All @@ -116,7 +112,6 @@ def generate_entity_import_files(
import_template = environment.get_template(f"{filename}.j2")
content = import_template.render(
classes=classes,
render_files=render_files,
view=view,
)
with open(os.path.join(output_prefix, filename), mode="w", encoding="utf-8") as outfile:
Expand All @@ -134,7 +129,7 @@ def regex_replace(txt, rgx, val, ignorecase=False, multiline=False):
return compiled_rgx.sub(val, txt)


def generate(schemafile: str, output_prefix: str, render_files: bool, template_override_paths: tuple[str]) -> None:
def generate(schemafile: str, output_prefix: str, template_override_paths: tuple[str]) -> None:
"""
Launch code generation
"""
Expand All @@ -157,48 +152,42 @@ def generate(schemafile: str, output_prefix: str, render_files: bool, template_o
# Generate enums and import files
generate_enums(output_prefix, environment, wrapped_view)
generate_limit_offset_type(output_prefix, environment)
generate_entity_import_files(output_prefix, environment, wrapped_view, render_files=render_files)
generate_entity_import_files(output_prefix, environment, wrapped_view)

# Generate database models, GraphQL types, Cerbos policies, and Factoryboy factories
generate_entity_subclass_files(
output_prefix,
"database/models/class_name.py",
environment,
wrapped_view,
render_files=render_files,
)
generate_entity_subclass_files(
output_prefix,
"graphql_api/types/class_name.py",
environment,
wrapped_view,
render_files=render_files,
)
generate_entity_subclass_files(
output_prefix,
"validators/class_name.py",
environment,
wrapped_view,
render_files=render_files,
)
generate_entity_subclass_files(
output_prefix,
"cerbos/policies/class_name.yaml",
environment,
wrapped_view,
render_files=render_files,
)
generate_entity_subclass_files(
output_prefix,
"test_infra/factories/class_name.py",
environment,
wrapped_view,
render_files=render_files,
)
generate_entity_subclass_files(
output_prefix,
"graphql_api/helpers/class_name.py",
environment,
wrapped_view,
render_files=render_files,
)
59 changes: 43 additions & 16 deletions platformics/codegen/lib/linkml_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
functions to keep complicated LinkML-specific logic out of our Jinja2 templates.
"""

import contextlib
from functools import cached_property

import strcase
Expand Down Expand Up @@ -35,10 +36,19 @@ def identifier(self) -> str:
def name(self) -> str:
return self.wrapped_field.name.replace(" ", "_")

@cached_property
def description(self) -> str:
# Make sure to quote this so it's safe!
return repr(self.wrapped_field.description)

@cached_property
def camel_name(self) -> str:
return strcase.to_lower_camel(self.name)

@cached_property
def type_designator(self) -> bool:
return bool(self.wrapped_field.designates_type)

@cached_property
def multivalued(self) -> str:
return self.wrapped_field.multivalued
Expand All @@ -47,6 +57,10 @@ def multivalued(self) -> str:
def required(self) -> bool:
return self.wrapped_field.required or False

@cached_property
def designates_type(self) -> bool:
return self.wrapped_field.designates_type

# Validation attributes
@cached_property
def minimum_value(self) -> float | int | None:
Expand All @@ -64,6 +78,11 @@ def maximum_value(self) -> float | int | None:
def indexed(self) -> bool:
if "indexed" in self.wrapped_field.annotations:
return self.wrapped_field.annotations["indexed"].value
if self.identifier:
return True
with contextlib.suppress(NotImplementedError, AttributeError, ValueError):
if self.related_class.identifier:
return True
return False

@cached_property
Expand Down Expand Up @@ -94,9 +113,7 @@ def hidden(self) -> bool:
@cached_property
def readonly(self) -> bool:
is_readonly = self.wrapped_field.readonly
if is_readonly:
return True
return False
return bool(is_readonly)

# Whether these fields should be available to change via an `Update` mutation
# All fields are mutable by default, so long as they're not marked as readonly
Expand Down Expand Up @@ -138,16 +155,12 @@ def inverse_field(self) -> str:
@cached_property
def is_enum(self) -> bool:
field = self.view.get_element(self.wrapped_field.range)
if isinstance(field, EnumDefinition):
return True
return False
return bool(isinstance(field, EnumDefinition))

@cached_property
def is_entity(self) -> bool:
field = self.view.get_element(self.wrapped_field.range)
if isinstance(field, ClassDefinition):
return True
return False
return bool(isinstance(field, ClassDefinition))

@property
def related_class(self) -> "EntityWrapper":
Expand All @@ -163,6 +176,17 @@ def factory_type(self) -> str | None:
return self.wrapped_field.annotations["factory_type"].value
return None

@cached_property
def is_single_parent(self) -> bool:
# TODO, this parameter probably needs a better name. It's entirely SA specific right now.
# Basically we need it to tell SQLAlchemy that we have a 1:many relationship without a backref.
# Normally that's fine on its own, but if SQLALchemy will not allow us to enable cascading
# deletes unless we promise it (with this flag) that a given "child" object has only one parent,
# thereby making it safe to delete when the parent is deleted
if "single_parent" in self.wrapped_field.annotations:
return self.wrapped_field.annotations["single_parent"].value
return False

@cached_property
def is_cascade_delete(self) -> bool:
if "cascade_delete" in self.wrapped_field.annotations:
Expand Down Expand Up @@ -234,15 +258,15 @@ def writable_fields(self) -> list[FieldWrapper]:
return [FieldWrapper(self.view, item) for item in self.view.class_induced_slots(self.name) if not item.readonly]

@cached_property
def identifier(self) -> str:
# Prioritize sending back identifiers from the entity mixin instead of inherited fields.
def identifier(self) -> FieldWrapper:
# Prioritize sending back identifiers from the current class and mixins instead of inherited fields.
domains_owned_by_this_class = set(self.wrapped_class.mixins + [self.name])
for field in self.all_fields:
# FIXME, the entity.id / entity_id relationship is a little brittle right now :(
if field.identifier and "EntityMixin" in field.wrapped_field.domain_of:
return field.name
if field.identifier and domains_owned_by_this_class.intersection(set(field.wrapped_field.domain_of)):
return field
for field in self.all_fields:
if field.identifier:
return field.name
if field.identifier and self.name in field.wrapped_field.domain_of:
return field
raise Exception("No identifier found")

@cached_property
Expand Down Expand Up @@ -343,6 +367,9 @@ def enums(self) -> list[EnumWrapper]:
enums = []
for enum_name in self.view.all_enums():
enum = self.view.get_element(enum_name)
# Don't codegen stuff that users asked us not to.
if enum.annotations.get("skip_codegen") and enum.annotations["skip_codegen"].value:
continue
enums.append(EnumWrapper(self.view, enum))
return enums

Expand Down
4 changes: 1 addition & 3 deletions platformics/codegen/templates/database/models/__init__.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@ Make changes to the template codegen/templates/database/models/__init__.py.j2 in

from sqlalchemy.orm import configure_mappers

from platformics.database.models import Base, meta, Entity, File, FileStatus # noqa: F401
from platformics.database.models import Base, meta, Entity # noqa: F401
{%- for class in classes %}
{%- if class.snake_name != "Entity" %}
from database.models.{{ class.snake_name }} import {{ class.name }} # noqa: F401
{%- endif %}
{%- endfor %}

from platformics.database.models.file import File, FileStatus # noqa: F401

configure_mappers()
21 changes: 12 additions & 9 deletions platformics/codegen/templates/database/models/class_name.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Make changes to the template codegen/templates/database/models/class_name.py.j2
"""

{% set related_fields = cls.related_fields | unique(attribute='related_class.name') | list %}
{% set ignored_fields = ["File", "Entity", cls.name] %}
{% set ignored_fields = ["Entity", cls.name] %}

import uuid
import datetime
Expand All @@ -26,19 +26,19 @@ from support.enums import
{%- endfor %}

if TYPE_CHECKING:
from database.models.file import File
{%- for related_field in related_fields %}
{%- if related_field.related_class.name not in ignored_fields %}
from database.models.{{related_field.related_class.snake_name}} import {{related_field.related_class.name}}
{%- endif %}
{%- endfor %}
pass
else:
File = "File"
{%- for related_field in related_fields %}
{%- if related_field.related_class.name not in ignored_fields %}
{{related_field.related_class.name}} = "{{related_field.related_class.name}}"
{%- endif %}
{%- endfor %}
pass


class {{cls.name}}(Entity):
Expand All @@ -65,12 +65,13 @@ class {{cls.name}}(Entity):
{%- elif attr.type == "date" %}
{{attr.name}}: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True), nullable={{ "False" if attr.required else "True"}}{%- if attr.indexed%}, index=True{%- endif %})
{%- else %}
{%- if attr.is_single_parent -%}
{% set single_parent = 'single_parent=True,' %}
{%- else -%}
{% set single_parent = "" %}
{%- endif -%}
{%- if attr.is_cascade_delete -%}
{%- if attr.type == "File" -%}
{% set cascade = 'cascade="all, delete-orphan", single_parent=True, post_update=True' %}
{%- else -%}
{% set cascade = 'cascade="all, delete-orphan"' %}
{%- endif -%}
{% set cascade = 'cascade="all, delete-orphan",' %}
{%- else -%}
{% set cascade = "" %}
{%- endif -%}
Expand All @@ -83,11 +84,12 @@ class {{cls.name}}(Entity):
uselist=True,
foreign_keys="{{attr.type}}.{{attr.inverse_field}}_id",
{{cascade}}
{{single_parent}}
)
{%- else %}
{{attr.name}}_id: Mapped[uuid.UUID] = mapped_column(
UUID,
ForeignKey("{{attr.related_class.snake_name}}.{{attr.related_class.identifier}}"),
ForeignKey("{{attr.related_class.snake_name}}.{{attr.related_class.identifier.name}}"),
nullable={{"False" if attr.required else "True"}},
{%- if attr.identifier %}
primary_key=True,
Expand All @@ -103,6 +105,7 @@ class {{cls.name}}(Entity):
back_populates="{{attr.inverse_field}}",
{%- endif %}
{{cascade}}
{{single_parent}}
)
{%- endif %}
{%- endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Auto-gereanted by running 'make codegen'. Do not edit.
Make changes to the template codegen/templates/graphql_api/groupby_helpers.py.j2 instead.
"""
{% set related_fields = cls.related_fields | unique(attribute='related_class.name') | list %}
{% set ignored_fields = ["File", "Entity", cls.name] %}
{% set ignored_fields = ["Entity", cls.name] %}

from typing import Any, Optional
import strawberry
Expand Down Expand Up @@ -35,8 +35,8 @@ These are only used in aggregate queries.
@strawberry.type
class {{ cls.name }}GroupByOptions:
{%- for attr in cls.visible_fields %}
{%- if attr.type != "File" and not attr.multivalued and not attr.is_virtual_relationship %}
{%- if attr.inverse %}
{%- if not attr.multivalued and not attr.is_virtual_relationship %}
{%- if attr.is_entity %}
{{ attr.name }}: Optional[{{ attr.related_class.name }}GroupByOptions] = None
{%- elif attr.type == cls.name %}
{{ attr.name }}: Optional["{{ cls.name }}GroupByOptions"] = None
Expand Down Expand Up @@ -78,7 +78,7 @@ def build_{{ cls.snake_name }}_groupby_output(
key = keys.pop(0)
match key:
{%- for attr in cls.visible_fields %}
{%- if attr.type != "File" and not attr.multivalued and not attr.is_virtual_relationship %}
{%- if not attr.multivalued and not attr.is_virtual_relationship %}
{%- if attr.inverse %}
case "{{ attr.related_class.snake_name }}":
if getattr(group_object, key):
Expand Down
11 changes: 0 additions & 11 deletions platformics/codegen/templates/graphql_api/mutations.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,13 @@ Make changes to the template codegen/templates/graphql_api/mutations.py.j2 inste

import strawberry
from typing import Sequence
{%- if render_files %}
from platformics.graphql_api.files import File, create_file, upload_file, upload_temporary_file, mark_upload_complete, concatenate_files, SignedURL, MultipartUploadResponse
{%- endif %}

{%- for class in classes %}
from graphql_api.types.{{ class.snake_name }} import {{ class.name }}, {%- if class.create_fields %}create_{{ class.snake_name }}, {%- endif %}{%- if class.mutable_fields %}update_{{ class.snake_name }}, {%- endif %}delete_{{ class.snake_name }}
{%- endfor %}

@strawberry.type
class Mutation:
{%- if render_files %}
# File mutations
create_file: File = create_file
upload_file: MultipartUploadResponse = upload_file
upload_temporary_file: MultipartUploadResponse = upload_temporary_file
mark_upload_complete: File = mark_upload_complete
concatenate_files: SignedURL = concatenate_files
{%- endif %}
{%- for class in classes %}

# {{ class.name }} mutations
Expand Down
Loading

0 comments on commit ac03bce

Please sign in to comment.