Skip to content

Commit

Permalink
Merge pull request #860
Browse files Browse the repository at this point in the history
Squashed from multiple commits

* Adds modify_entries to qcfractal socket and routes.

* Adds modify_entries to qcportal.

* Adds testing functions for the modify_entries behavior.
  • Loading branch information
sjayellis authored Dec 19, 2024
1 parent 700e30d commit 4bd800c
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 4 deletions.
12 changes: 11 additions & 1 deletion qcfractal/qcfractal/components/dataset_routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Any

from flask import current_app, g

Expand All @@ -20,6 +20,7 @@
DatasetModifyMetadata,
DatasetQueryRecords,
DatasetDeleteParams,
DatasetModifyEntryBody,
)
from qcportal.exceptions import LimitExceededError

Expand Down Expand Up @@ -267,6 +268,15 @@ def rename_dataset_entries_v1(dataset_type: str, dataset_id: int, body_data: Dic
return ds_socket.rename_entries(dataset_id, body_data)


@api_v1.route("/datasets/<string:dataset_type>/<int:dataset_id>/entries/modify", methods=["PATCH"])
@wrap_route("WRITE")
def modify_dataset_entries_v1(dataset_type: str, dataset_id: int, body_data: DatasetModifyEntryBody):
ds_socket = storage_socket.datasets.get_socket(dataset_type)
return ds_socket.modify_entries(
dataset_id, body_data.attribute_map, body_data.comment_map, body_data.overwrite_attributes
)


#########################
# Records
#########################
Expand Down
63 changes: 63 additions & 0 deletions qcfractal/qcfractal/components/dataset_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sqlalchemy import select, delete, func, union, text, and_
from sqlalchemy.orm import load_only, lazyload, joinedload, with_polymorphic
from sqlalchemy.orm.attributes import flag_modified

from qcfractal.components.dataset_db_models import BaseDatasetORM, ContributedValuesORM
from qcfractal.components.record_db_models import BaseRecordORM
Expand Down Expand Up @@ -983,6 +984,68 @@ def rename_entries(self, dataset_id: int, entry_name_map: Dict[str, str], *, ses
for entry in entries:
entry.name = entry_name_map[entry.name]

def modify_entries(
    self,
    dataset_id: int,
    attribute_map: Optional[Dict[str, Dict[str, Any]]] = None,
    comment_map: Optional[Dict[str, str]] = None,
    overwrite_attributes: bool = False,
    *,
    session: Optional[Session] = None,
):
    """
    Modify the attributes and/or comments of the entries in a dataset.

    If overwrite_attributes is True, replaces an entry's existing attributes with the value
    in attribute_map. If False, performs a shallow update: existing fields within the
    attributes are updated and non-existing fields are added.

    Parameters
    ----------
    dataset_id
        ID of a dataset
    attribute_map
        Mapping of entry names to new attribute data
    comment_map
        Mapping of entry names to new comments
    overwrite_attributes
        Boolean to indicate if existing attributes should be overwritten rather than merged
    session
        An existing SQLAlchemy session to use. If None, one will be created. If an existing session
        is used, it will be flushed (but not committed) before returning from this function.
    """

    attribute_names = set(attribute_map.keys()) if attribute_map is not None else set()
    comment_names = set(comment_map.keys()) if comment_map is not None else set()

    stmt = select(self.entry_orm)
    stmt = stmt.where(self.entry_orm.dataset_id == dataset_id)
    stmt = stmt.where(self.entry_orm.name.in_(attribute_names | comment_names))
    stmt = stmt.options(load_only(self.entry_orm.name, self.entry_orm.attributes, self.entry_orm.comment))
    stmt = stmt.options(lazyload("*"))
    # Lock the selected rows for the duration of the update; wait for (not skip) locked rows
    stmt = stmt.with_for_update(skip_locked=False)

    with self.root_socket.optional_session(session) as session:
        entries = session.execute(stmt).scalars().all()

        for entry in entries:
            if entry.name in attribute_names:
                if overwrite_attributes:
                    entry.attributes = attribute_map[entry.name]
                else:
                    entry.attributes.update(attribute_map[entry.name])
                    # In-place mutation of a JSON column is not detected by SQLAlchemy's
                    # change tracking; mark it dirty explicitly so the merge is flushed.
                    # NOTE: this is only needed (and only correct) when this entry's
                    # attributes were actually touched - previously it ran for every
                    # fetched entry, forcing spurious UPDATEs on comment-only changes.
                    flag_modified(entry, "attributes")

            if entry.name in comment_names:
                entry.comment = comment_map[entry.name]

def fetch_records(
self,
dataset_id: int,
Expand Down
32 changes: 32 additions & 0 deletions qcportal/qcportal/dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,32 @@ def rename_entries(self, name_map: Dict[str, str]):
for old_name, new_name in name_map.items():
self._cache_data.rename_entry(old_name, new_name)

def modify_entries(
    self,
    attribute_map: Optional[Dict[str, Dict[str, Any]]] = None,
    comment_map: Optional[Dict[str, str]] = None,
    overwrite_attributes: bool = False,
):
    """Send an entry-modification request to the server, then refresh the local cache."""
    self.assert_is_not_view()
    self.assert_online()

    self._client.make_request(
        "patch",
        f"api/v1/datasets/{self.dataset_type}/{self.id}/entries/modify",
        None,
        body=DatasetModifyEntryBody(
            attribute_map=attribute_map,
            comment_map=comment_map,
            overwrite_attributes=overwrite_attributes,
        ),
    )

    # Every entry named in either map may have changed server-side; re-fetch them
    touched = set()
    for mapping in (attribute_map, comment_map):
        if mapping is not None:
            touched.update(mapping.keys())

    self.fetch_entries(touched, force_refetch=True)

def delete_entries(self, names: Union[str, Iterable[str]], delete_records: bool = False) -> DeleteMetadata:
self.assert_is_not_view()
self.assert_online()
Expand Down Expand Up @@ -1796,6 +1822,12 @@ class DatasetDeleteSpecificationBody(RestModelBase):
delete_records: bool = False


class DatasetModifyEntryBody(RestModelBase):
    # Request body for the dataset entries/modify endpoint.
    # Maps entry name -> new attribute data (merged into or replacing existing attributes,
    # depending on overwrite_attributes)
    attribute_map: Optional[Dict[str, Dict[str, Any]]] = None
    # Maps entry name -> new comment text
    comment_map: Optional[Dict[str, str]] = None
    # If True, replace each entry's attributes entirely; if False, shallow-merge into them
    overwrite_attributes: bool = False


def dataset_from_dict(data: Dict[str, Any], client: Any, cache_data: Optional[DatasetCache] = None) -> BaseDataset:
"""
Create a dataset object from a datamodel
Expand Down
55 changes: 55 additions & 0 deletions qcportal/qcportal/dataset_testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,61 @@ def run_dataset_model_rename_entry(snowflake_client, ds, test_entries, test_spec
assert ds._cache_data.get_dataset_record(entry_name_3, "spec_1").id == ent_rec_map[entry_name_3].id


def run_dataset_model_modify_entries(snowflake_client, ds, test_entries, test_specs):
    """Exercise Dataset.modify_entries: merged vs. overwritten attributes, and comment updates."""
    ds.add_specification("spec_1", test_specs[0])
    ds.add_entries(test_entries)
    ds.fetch_entries()

    target = test_entries[1].name

    # --- overwrite_attributes=False: new attributes are merged into the existing dict ---
    expected = {**test_entries[1].attributes, "test_attr_1": "val", "test_attr_2": 5}
    ds.modify_entries(attribute_map={target: {"test_attr_1": "val", "test_attr_2": 5}})
    assert ds.get_entry(target).attributes == expected

    expected = {**expected, "test_attr_1": "new_val", "test_attr_2": 10}
    ds.modify_entries(attribute_map={target: {"test_attr_1": "new_val", "test_attr_2": 10}})
    assert ds.get_entry(target).attributes == expected

    # Attributes and comment together in one call
    expected = {**expected, "test_attr_1": "new_value", "test_attr_2": 19}
    ds.modify_entries(
        attribute_map={target: {"test_attr_1": "new_value", "test_attr_2": 19}},
        comment_map={target: "This is a new comment for the entry."},
    )
    assert ds.get_entry(target).attributes == expected
    assert ds.get_entry(target).comment == "This is a new comment for the entry."

    # --- overwrite_attributes=True: attributes are replaced wholesale ---
    expected = {"test_attr_1": "val", "test_attr_2": 5}
    ds.modify_entries(attribute_map={target: {"test_attr_1": "val", "test_attr_2": 5}}, overwrite_attributes=True)
    assert ds.get_entry(target).attributes == expected

    # Comment only; attributes must be left untouched
    ds.modify_entries(comment_map={target: "This is a new comment tested without modifying attributes."})
    assert ds.get_entry(target).attributes == expected
    assert ds.get_entry(target).comment == "This is a new comment tested without modifying attributes."

    # Both attributes (overwritten) and comment
    expected = {"test_attr_1": "value"}
    ds.modify_entries(
        attribute_map={target: {"test_attr_1": "value"}},
        comment_map={target: "This is a new comment while overwriting the attributes."},
        overwrite_attributes=True,
    )
    assert ds.get_entry(target).attributes == expected
    assert ds.get_entry(target).comment == "This is a new comment while overwriting the attributes."

    # Re-fetch the dataset from the server to confirm the changes were persisted
    ds = snowflake_client.get_dataset_by_id(ds.id)
    ds.fetch_entries()
    assert ds.get_entry(target).attributes == expected
    assert ds.get_entry(target).comment == "This is a new comment while overwriting the attributes."


def run_dataset_model_delete_entry(snowflake_client, ds, test_entries, test_specs):
ds.add_specification("spec_1", test_specs[0])
ds.add_entries(test_entries)
Expand Down
5 changes: 5 additions & 0 deletions qcportal/qcportal/gridoptimization/test_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def test_gridoptimization_dataset_model_rename_entry(snowflake_client: PortalCli
ds_helpers.run_dataset_model_rename_entry(snowflake_client, ds, test_entries, test_specs)


def test_gridoptimization_dataset_model_modify_entries(snowflake_client: PortalClient):
    """Entry attribute/comment modification on a gridoptimization dataset."""
    dataset = snowflake_client.add_dataset("gridoptimization", "Test dataset")
    ds_helpers.run_dataset_model_modify_entries(snowflake_client, dataset, test_entries, test_specs)


def test_gridoptimization_dataset_model_delete_entry(snowflake_client: PortalClient):
ds = snowflake_client.add_dataset("gridoptimization", "Test dataset")
ds_helpers.run_dataset_model_delete_entry(snowflake_client, ds, test_entries, test_specs)
Expand Down
5 changes: 5 additions & 0 deletions qcportal/qcportal/manybody/test_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def test_manybody_dataset_model_rename_entry(snowflake_client: PortalClient):
ds_helpers.run_dataset_model_rename_entry(snowflake_client, ds, test_entries, test_specs)


def test_manybody_dataset_model_modify_entries(snowflake_client: PortalClient):
    """Entry attribute/comment modification on a manybody dataset."""
    dataset = snowflake_client.add_dataset("manybody", "Test dataset")
    ds_helpers.run_dataset_model_modify_entries(snowflake_client, dataset, test_entries, test_specs)


def test_manybody_dataset_model_delete_entry(snowflake_client: PortalClient):
ds = snowflake_client.add_dataset("manybody", "Test dataset")
ds_helpers.run_dataset_model_delete_entry(snowflake_client, ds, test_entries, test_specs)
Expand Down
5 changes: 5 additions & 0 deletions qcportal/qcportal/neb/test_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def test_neb_dataset_model_rename_entry(snowflake_client: PortalClient):
ds_helpers.run_dataset_model_rename_entry(snowflake_client, ds, test_entries, test_specs)


def test_neb_dataset_model_modify_entries(snowflake_client: PortalClient):
    """Entry attribute/comment modification on a neb dataset."""
    dataset = snowflake_client.add_dataset("neb", "Test dataset")
    ds_helpers.run_dataset_model_modify_entries(snowflake_client, dataset, test_entries, test_specs)


def test_neb_dataset_model_delete_entry(snowflake_client: PortalClient):
ds = snowflake_client.add_dataset("neb", "Test dataset")
ds_helpers.run_dataset_model_delete_entry(snowflake_client, ds, test_entries, test_specs)
Expand Down
5 changes: 5 additions & 0 deletions qcportal/qcportal/optimization/test_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def test_optimization_dataset_model_rename_entry(snowflake_client: PortalClient)
ds_helpers.run_dataset_model_rename_entry(snowflake_client, ds, test_entries, test_specs)


def test_optimization_dataset_model_modify_entries(snowflake_client: PortalClient):
    """Entry attribute/comment modification on an optimization dataset."""
    dataset = snowflake_client.add_dataset("optimization", "Test dataset")
    ds_helpers.run_dataset_model_modify_entries(snowflake_client, dataset, test_entries, test_specs)


def test_optimization_dataset_model_delete_entry(snowflake_client: PortalClient):
ds = snowflake_client.add_dataset("optimization", "Test dataset")
ds_helpers.run_dataset_model_delete_entry(snowflake_client, ds, test_entries, test_specs)
Expand Down
5 changes: 5 additions & 0 deletions qcportal/qcportal/reaction/test_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ def test_reaction_dataset_model_rename_entry(snowflake_client: PortalClient):
ds_helpers.run_dataset_model_rename_entry(snowflake_client, ds, test_entries, test_specs)


def test_reaction_dataset_model_modify_entries(snowflake_client: PortalClient):
    """Entry attribute/comment modification on a reaction dataset."""
    dataset = snowflake_client.add_dataset("reaction", "Test dataset")
    ds_helpers.run_dataset_model_modify_entries(snowflake_client, dataset, test_entries, test_specs)


def test_reaction_dataset_model_delete_entry(snowflake_client: PortalClient):
ds = snowflake_client.add_dataset("reaction", "Test dataset")
ds_helpers.run_dataset_model_delete_entry(snowflake_client, ds, test_entries, test_specs)
Expand Down
5 changes: 5 additions & 0 deletions qcportal/qcportal/singlepoint/test_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def test_singlepoint_dataset_model_rename_entry(snowflake_client: PortalClient):
ds_helpers.run_dataset_model_rename_entry(snowflake_client, ds, test_entries, test_specs)


def test_singlepoint_dataset_model_modify_entries(snowflake_client: PortalClient):
    """Entry attribute/comment modification on a singlepoint dataset."""
    dataset = snowflake_client.add_dataset("singlepoint", "Test dataset")
    ds_helpers.run_dataset_model_modify_entries(snowflake_client, dataset, test_entries, test_specs)


def test_singlepoint_dataset_model_delete_entry(snowflake_client: PortalClient):
ds = snowflake_client.add_dataset("singlepoint", "Test dataset")
ds_helpers.run_dataset_model_delete_entry(snowflake_client, ds, test_entries, test_specs)
Expand Down
5 changes: 5 additions & 0 deletions qcportal/qcportal/torsiondrive/test_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def test_torsiondrive_dataset_model_rename_entry(snowflake_client: PortalClient)
ds_helpers.run_dataset_model_rename_entry(snowflake_client, ds, test_entries, test_specs)


def test_torsiondrive_dataset_model_modify_entries(snowflake_client: PortalClient):
    """Entry attribute/comment modification on a torsiondrive dataset."""
    dataset = snowflake_client.add_dataset("torsiondrive", "Test dataset")
    ds_helpers.run_dataset_model_modify_entries(snowflake_client, dataset, test_entries, test_specs)


def test_torsiondrive_dataset_model_delete_entry(snowflake_client: PortalClient):
ds = snowflake_client.add_dataset("torsiondrive", "Test dataset")
ds_helpers.run_dataset_model_delete_entry(snowflake_client, ds, test_entries, test_specs)
Expand Down
16 changes: 13 additions & 3 deletions qcportal/qcportal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from hashlib import sha256
from typing import Optional, Union, Sequence, List, TypeVar, Any, Dict, Generator, Iterable, Callable
from typing import Optional, Union, Sequence, List, TypeVar, Any, Dict, Generator, Iterable, Callable, Set

import numpy as np

Expand All @@ -19,16 +19,26 @@
_T = TypeVar("_T")


def make_list(obj: Optional[Union[_T, Sequence[_T], Set[_T]]]) -> Optional[List[_T]]:
    """
    Returns a list containing obj if obj is not a list or other iterable type object.

    This also works with sets and frozensets, which are expanded into a list.
    None is passed through unchanged, and an existing list is returned as-is
    (not copied).
    """

    # NOTE - you might be tempted to change this to work with Iterable rather than Sequence. However,
    # pydantic models and dicts and stuff are sequences, too, which we usually just want to return
    # within a list

    if obj is None:
        return None
    if isinstance(obj, list):
        return obj
    # Be careful. strings are sequences
    if isinstance(obj, str):
        return [obj]
    # frozenset is NOT a subclass of set, so both must be checked explicitly;
    # otherwise a frozenset would fall through and be wrapped as a single element
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    if not isinstance(obj, Sequence):
        return [obj]
    return list(obj)
Expand Down

0 comments on commit 4bd800c

Please sign in to comment.