diff --git a/CHANGELOG.md b/CHANGELOG.md
index d192b98..f68f61c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,10 @@ The format is roughly based on [Keep a Changelog](https://keepachangelog.com/en/
+## 0.39.0
+
+- More advanced process-graph splitting for cross-backend execution: no longer limited to splitting off `load_collection` nodes, but able to cut deeper into the graph. ([#150](https://github.com/Open-EO/openeo-aggregator/issues/150))
+
## 0.38.0
- Add request timeout configs for listing user jobs (eu-cdse/openeo-cdse-infra#188)
diff --git a/scripts/crossbackend-processing-poc.py b/scripts/crossbackend-processing-poc.py
index e8b83e5..01f6a33 100644
--- a/scripts/crossbackend-processing-poc.py
+++ b/scripts/crossbackend-processing-poc.py
@@ -7,7 +7,8 @@
from openeo_aggregator.metadata import STAC_PROPERTY_FEDERATION_BACKENDS
from openeo_aggregator.partitionedjobs import PartitionedJob
from openeo_aggregator.partitionedjobs.crossbackend import (
- CrossBackendSplitter,
+ CrossBackendJobSplitter,
+ LoadCollectionGraphSplitter,
run_partitioned_job,
)
@@ -62,7 +63,9 @@ def backend_for_collection(collection_id) -> str:
metadata = connection.describe_collection(collection_id)
return metadata["summaries"][STAC_PROPERTY_FEDERATION_BACKENDS][0]
- splitter = CrossBackendSplitter(backend_for_collection=backend_for_collection, always_split=True)
+ splitter = CrossBackendJobSplitter(
+ graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=backend_for_collection, always_split=True)
+ )
pjob: PartitionedJob = splitter.split({"process_graph": process_graph})
_log.info(f"Partitioned job: {pjob!r}")
diff --git a/src/openeo_aggregator/about.py b/src/openeo_aggregator/about.py
index 5287e3d..febd27f 100644
--- a/src/openeo_aggregator/about.py
+++ b/src/openeo_aggregator/about.py
@@ -2,7 +2,7 @@
import sys
from typing import Optional
-__version__ = "0.38.0a1"
+__version__ = "0.39.0a1"
def log_version_info(logger: Optional[logging.Logger] = None):
diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py
index b568ce3..398aaeb 100644
--- a/src/openeo_aggregator/backend.py
+++ b/src/openeo_aggregator/backend.py
@@ -84,6 +84,7 @@
streaming_flask_response,
)
from openeo_aggregator.constants import (
+ CROSSBACKEND_GRAPH_SPLIT_METHOD,
JOB_OPTION_FORCE_BACKEND,
JOB_OPTION_SPLIT_STRATEGY,
JOB_OPTION_TILE_GRID,
@@ -100,7 +101,11 @@
single_backend_collection_post_processing,
)
from openeo_aggregator.partitionedjobs import PartitionedJob
-from openeo_aggregator.partitionedjobs.crossbackend import CrossBackendSplitter
+from openeo_aggregator.partitionedjobs.crossbackend import (
+ CrossBackendJobSplitter,
+ DeepGraphSplitter,
+ LoadCollectionGraphSplitter,
+)
from openeo_aggregator.partitionedjobs.splitting import FlimsySplitter, TileGridSplitter
from openeo_aggregator.partitionedjobs.tracking import (
PartitionedJobConnection,
@@ -803,25 +808,34 @@ def create_job(
if "process_graph" not in process:
raise ProcessGraphMissingException()
- # TODO: better, more generic/specific job_option(s)?
- if job_options and (job_options.get(JOB_OPTION_SPLIT_STRATEGY) or job_options.get(JOB_OPTION_TILE_GRID)):
- if job_options.get(JOB_OPTION_SPLIT_STRATEGY) == "crossbackend":
- # TODO this is temporary feature flag to trigger "crossbackend" splitting
- return self._create_crossbackend_job(
- user_id=user_id,
- process=process,
- api_version=api_version,
- metadata=metadata,
- job_options=job_options,
- )
- else:
- return self._create_partitioned_job(
- user_id=user_id,
- process=process,
- api_version=api_version,
- metadata=metadata,
- job_options=job_options,
- )
+        # Handling of the messy "split_strategy" job option
+ # Also see https://github.com/Open-EO/openeo-aggregator/issues/156
+        # TODO: more generic and future-proof handling of split-strategy-related options?
+ split_strategy = (job_options or {}).get(JOB_OPTION_SPLIT_STRATEGY)
+ tile_grid = (job_options or {}).get(JOB_OPTION_TILE_GRID)
+ crossbackend_mode = split_strategy == "crossbackend" or (
+ isinstance(split_strategy, dict) and "crossbackend" in split_strategy
+ )
+ # TODO: the legacy job option "tile_grid" is quite generic and not very explicit
+        # about being a job splitting approach. Can we deprecate it in some way?
+ spatial_split_mode = tile_grid or split_strategy == "flimsy"
+
+ if crossbackend_mode:
+ return self._create_crossbackend_job(
+ user_id=user_id,
+ process=process,
+ api_version=api_version,
+ metadata=metadata,
+ job_options=job_options,
+ )
+ elif spatial_split_mode:
+ return self._create_partitioned_job(
+ user_id=user_id,
+ process=process,
+ api_version=api_version,
+ metadata=metadata,
+ job_options=job_options,
+ )
else:
return self._create_job_standard(
user_id=user_id,
@@ -936,14 +950,47 @@ def _create_crossbackend_job(
if not self.partitioned_job_tracker:
raise FeatureUnsupportedException(message="Partitioned job tracking is not supported")
- def backend_for_collection(collection_id) -> str:
- return self._catalog.get_backends_for_collection(cid=collection_id)[0]
+ split_strategy = (job_options or {}).get(JOB_OPTION_SPLIT_STRATEGY)
+ if split_strategy == "crossbackend":
+ # Legacy job option format
+ graph_split_method = CROSSBACKEND_GRAPH_SPLIT_METHOD.SIMPLE
+ elif isinstance(split_strategy, dict) and isinstance(split_strategy.get("crossbackend"), dict):
+ graph_split_method = split_strategy.get("crossbackend", {}).get(
+ "method", CROSSBACKEND_GRAPH_SPLIT_METHOD.SIMPLE
+ )
+ else:
+ raise ValueError(f"Invalid split strategy {split_strategy!r}")
- splitter = CrossBackendSplitter(
- backend_for_collection=backend_for_collection,
- # TODO: job option for `always_split` feature?
- always_split=True,
- )
+ _log.info(f"_create_crossbackend_job: {graph_split_method=} from {split_strategy=}")
+ if graph_split_method == CROSSBACKEND_GRAPH_SPLIT_METHOD.SIMPLE:
+
+ def backend_for_collection(collection_id) -> str:
+ return self._catalog.get_backends_for_collection(cid=collection_id)[0]
+
+ graph_splitter = LoadCollectionGraphSplitter(
+ backend_for_collection=backend_for_collection,
+ # TODO: job option for `always_split` feature?
+ always_split=True,
+ )
+ elif graph_split_method == CROSSBACKEND_GRAPH_SPLIT_METHOD.DEEP:
+
+ def supporting_backends(node_id: str, node: dict) -> Union[List[str], None]:
+                # TODO: wider coverage, e.g. also check process id availability?
+ if node["process_id"] == "load_collection":
+ collection_id = node["arguments"]["id"]
+ return self._catalog.get_backends_for_collection(cid=collection_id)
+
+ graph_splitter = DeepGraphSplitter(
+ supporting_backends=supporting_backends,
+ primary_backend=split_strategy.get("crossbackend", {}).get("primary_backend"),
+ # TODO: instead of this hardcoded deny-list, build it based on backend metadata inspection?
+ # TODO: make a config for this?
+ split_deny_list={"aggregate_spatial", "load_geojson", "load_url"},
+ )
+ else:
+            raise ValueError(f"Invalid graph split method {graph_split_method!r}")
+
+ splitter = CrossBackendJobSplitter(graph_splitter=graph_splitter)
pjob_id = self.partitioned_job_tracker.create_crossbackend_pjob(
user_id=user_id, process=process, metadata=metadata, job_options=job_options, splitter=splitter
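
For reference, a hedged sketch of the `job_options` payloads that exercise the branches above (the backend id `"b1"` is just an illustrative value, mirroring the test fixtures):

```python
# Legacy string form: triggers the "simple" load_collection-based split.
job_options = {"split_strategy": "crossbackend"}

# Dict form: explicit graph split method selection.
job_options = {"split_strategy": {"crossbackend": {"method": "simple"}}}
job_options = {"split_strategy": {"crossbackend": {"method": "deep", "primary_backend": "b1"}}}

# These are sent in the job creation request body, e.g.
# POST /jobs
# {"process": {"process_graph": {...}}, "job_options": {"split_strategy": ...}}
```
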
diff --git a/src/openeo_aggregator/constants.py b/src/openeo_aggregator/constants.py
index 68621bd..02a7a4a 100644
--- a/src/openeo_aggregator/constants.py
+++ b/src/openeo_aggregator/constants.py
@@ -4,3 +4,9 @@
# Experimental feature to force a certain upstream back-end through job options
JOB_OPTION_FORCE_BACKEND = "_agg_force_backend"
+
+
+class CROSSBACKEND_GRAPH_SPLIT_METHOD:
+ # Poor-man's StrEnum
+ SIMPLE = "simple"
+ DEEP = "deep"
diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py
index 9c97300..076b021 100644
--- a/src/openeo_aggregator/partitionedjobs/crossbackend.py
+++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -1,27 +1,68 @@
+from __future__ import annotations
+
+import abc
import collections
import copy
+import dataclasses
import datetime
+import fractions
+import functools
import itertools
import logging
import time
+import types
from contextlib import nullcontext
-from typing import Callable, Dict, Iterator, List, Optional, Protocol, Sequence, Tuple
+from typing import (
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Mapping,
+ NamedTuple,
+ Optional,
+ Protocol,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
import openeo
from openeo import BatchJob
+from openeo.util import deep_get
from openeo_driver.jobregistry import JOB_STATUS
from openeo_aggregator.constants import JOB_OPTION_FORCE_BACKEND
from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
from openeo_aggregator.partitionedjobs.splitting import AbstractJobSplitter
-from openeo_aggregator.utils import FlatPG, PGWithMetadata, SkipIntermittentFailures
+from openeo_aggregator.utils import (
+ _UNSET,
+ FlatPG,
+ PGWithMetadata,
+ SkipIntermittentFailures,
+)
_log = logging.getLogger(__name__)
_LOAD_RESULT_PLACEHOLDER = "_placeholder:"
# Some type annotation aliases to make things more self-documenting
+CollectionId = str
SubGraphId = str
+NodeId = str
+BackendId = str
+ProcessId = str
+
+
+# Annotation for a function that maps node information (id and its node dict)
+# to id(s) of the backend(s) that support it.
+# Returning None means that support is unconstrained (any backend is assumed to support it).
+SupportingBackendsMapper = Callable[[NodeId, dict], Union[BackendId, Iterable[BackendId], None]]
+
+
+class GraphSplitException(Exception):
+ pass
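
As an illustration of the `SupportingBackendsMapper` contract above, a minimal sketch of such a mapper (the collection-to-backend mapping is made up for the example):

```python
from typing import Iterable, Union


def example_supporting_backends(node_id: str, node: dict) -> Union[Iterable[str], None]:
    # Hypothetical mapping: constrain load_collection nodes to the back-ends
    # that offer the collection; leave all other nodes unconstrained (None).
    backends_per_collection = {"S2": ["b1"], "T22": ["b2"]}
    if node.get("process_id") == "load_collection":
        return backends_per_collection.get(node["arguments"]["id"])
    return None
```
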
class GetReplacementCallable(Protocol):
@@ -57,49 +98,57 @@ def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId)
}
-class CrossBackendSplitter(AbstractJobSplitter):
- """
- Split a process graph, to be executed across multiple back-ends,
- based on availability of collections.
+class _PGSplitSubGraph(NamedTuple):
+ """Container for result of ProcessGraphSplitterInterface.split"""
+
+ split_node: NodeId
+ node_ids: Set[NodeId]
+ backend_id: BackendId
- .. warning::
- this is experimental functionality
+class _PGSplitResult(NamedTuple):
+ """Container for result of ProcessGraphSplitterInterface.split"""
+
+ primary_node_ids: Set[NodeId]
+ primary_backend_id: BackendId
+ secondary_graphs: List[_PGSplitSubGraph]
+
+
+class ProcessGraphSplitterInterface(metaclass=abc.ABCMeta):
+ """
+ Interface for process graph splitters:
+ given a process graph (flat graph representation),
+ produce a main graph and secondary graphs (as subsets of node ids)
+ and the backends they are supposed to run on.
"""
- def __init__(self, backend_for_collection: Callable[[str], str], always_split: bool = False):
- """
- :param backend_for_collection: callable that determines backend id for given collection id
- :param always_split: split all load_collections, also when on same backend
+ @abc.abstractmethod
+ def split(self, process_graph: FlatPG) -> _PGSplitResult:
"""
- # TODO: just handle this `backend_for_collection` callback with a regular method?
- self.backend_for_collection = backend_for_collection
- self._always_split = always_split
+        Split given process graph (flat graph representation) into subgraphs.
- def split_streaming(
- self,
- process_graph: FlatPG,
- get_replacement: GetReplacementCallable = _default_get_replacement,
- main_subgraph_id: SubGraphId = "main",
- ) -> Iterator[Tuple[SubGraphId, SubJob, List[SubGraphId]]]:
+ Returns primary graph data (node ids and backend id)
+        and secondary graphs data (list of tuples: split node id, subgraph node ids, backend id)
"""
- Split given process graph in sub-process graphs and return these as an iterator
- in an order so that a subgraph comes after all subgraphs it depends on
- (e.g. main "primary" graph comes last).
+ ...
- The iterator approach allows working with a dynamic `get_replacement` implementation
- that adapting to on previously produced subgraphs
- (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately).
- :return: tuple containing:
- - subgraph id, recommended to handle it as opaque id (but usually format '{backend_id}:{node_id}')
- - SubJob
- - dependencies as list of subgraph ids
- """
+class LoadCollectionGraphSplitter(ProcessGraphSplitterInterface):
+ """
+ Simple process graph splitter that just splits off load_collection nodes.
+ """
+
+ # TODO: migrate backend_for_collection to SupportingBackendsMapper format?
+ def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False):
+        # TODO: also support not having a backend_for_collection map?
+ self._backend_for_collection = backend_for_collection
+ self._always_split = always_split
+
+ def split(self, process_graph: FlatPG) -> _PGSplitResult:
# Extract necessary back-ends from `load_collection` usage
backend_per_collection: Dict[str, str] = {
- cid: self.backend_for_collection(cid)
+ cid: self._backend_for_collection(cid)
for cid in (
node["arguments"]["id"] for node in process_graph.values() if node["process_id"] == "load_collection"
)
@@ -112,45 +161,93 @@ def split_streaming(
secondary_backends = {b for b in backend_usage if b != primary_backend}
_log.info(f"Backend split: {primary_backend=} {secondary_backends=}")
- primary_id = main_subgraph_id
- primary_pg = {}
primary_has_load_collection = False
- primary_dependencies = []
-
+ primary_graph_node_ids = set()
+ secondary_graphs: List[_PGSplitSubGraph] = []
for node_id, node in process_graph.items():
if node["process_id"] == "load_collection":
bid = backend_per_collection[node["arguments"]["id"]]
if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
- # Add to primary pg
- primary_pg[node_id] = node
+ primary_graph_node_ids.add(node_id)
primary_has_load_collection = True
else:
- # New secondary pg
- sub_id = f"{bid}:{node_id}"
- sub_pg = {
- node_id: node,
- "sr1": {
- # TODO: other/better choices for save_result format (e.g. based on backend support)?
- "process_id": "save_result",
- "arguments": {
- "data": {"from_node": node_id},
- # TODO: particular format options?
- # "format": "NetCDF",
- "format": "GTiff",
- },
- "result": True,
- },
- }
+ secondary_graphs.append(_PGSplitSubGraph(split_node=node_id, node_ids={node_id}, backend_id=bid))
+ else:
+ primary_graph_node_ids.add(node_id)
- yield (sub_id, SubJob(process_graph=sub_pg, backend_id=bid), [])
+ return _PGSplitResult(
+ primary_node_ids=primary_graph_node_ids,
+ primary_backend_id=primary_backend,
+ secondary_graphs=secondary_graphs,
+ )
- # Link secondary pg into primary pg
- primary_pg.update(get_replacement(node_id=node_id, node=node, subgraph_id=sub_id))
- primary_dependencies.append(sub_id)
- else:
- primary_pg[node_id] = node
- yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies)
+class CrossBackendJobSplitter(AbstractJobSplitter):
+ """
+ Split a process graph, to be executed across multiple back-ends,
+ based on availability of collections.
+
+ .. warning::
+ this is experimental functionality
+
+ """
+
+ def __init__(self, graph_splitter: ProcessGraphSplitterInterface):
+ self._graph_splitter = graph_splitter
+
+ def split_streaming(
+ self,
+ process_graph: FlatPG,
+ get_replacement: GetReplacementCallable = _default_get_replacement,
+ main_subgraph_id: SubGraphId = "main",
+ ) -> Iterator[Tuple[SubGraphId, SubJob, List[SubGraphId]]]:
+ """
+        Split given process graph into sub-process graphs and return these as an iterator
+ in an order so that a subgraph comes after all subgraphs it depends on
+ (e.g. main "primary" graph comes last).
+
+ The iterator approach allows working with a dynamic `get_replacement` implementation
+        that can adapt to previously produced subgraphs
+ (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately).
+
+ :return: Iterator of tuples containing:
+            - subgraph id: recommended to handle it as an opaque id (but usually of the format '{backend_id}:{node_id}')
+ - SubJob
+ - dependencies as list of subgraph ids
+ """
+
+ graph_split_result = self._graph_splitter.split(process_graph=process_graph)
+
+ primary_pg = {k: process_graph[k] for k in graph_split_result.primary_node_ids}
+ primary_dependencies = []
+
+ for node_id, subgraph_node_ids, backend_id in graph_split_result.secondary_graphs:
+ # New secondary pg
+ sub_id = f"{backend_id}:{node_id}"
+ sub_pg = {k: v for k, v in process_graph.items() if k in subgraph_node_ids}
+ # Add new `save_result` node to the subgraphs
+ sub_pg["_agg_crossbackend_save_result"] = {
+ # TODO: other/better choices for save_result format (e.g. based on backend support, cube type)?
+ "process_id": "save_result",
+ "arguments": {
+ "data": {"from_node": node_id},
+ # TODO: particular format options?
+ # "format": "NetCDF",
+ "format": "GTiff",
+ },
+ "result": True,
+ }
+ yield (sub_id, SubJob(process_graph=sub_pg, backend_id=backend_id), [])
+
+ # Link secondary pg into primary pg
+ primary_pg.update(get_replacement(node_id=node_id, node=process_graph[node_id], subgraph_id=sub_id))
+ primary_dependencies.append(sub_id)
+
+ yield (
+ main_subgraph_id,
+ SubJob(process_graph=primary_pg, backend_id=graph_split_result.primary_backend_id),
+ primary_dependencies,
+ )
def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
"""Split given process graph into a `PartitionedJob`"""
@@ -351,3 +448,465 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection, fai
}
for sid in subjobs.keys()
}
+
+
+def to_frozenset(value: Union[Iterable[str], str]) -> frozenset[str]:
+ """Coerce value to frozenset of strings"""
+ if isinstance(value, str):
+ value = [value]
+ return frozenset(value)
+
+
+@dataclasses.dataclass(frozen=True, init=False, eq=True)
+class _GVNode:
+ """
+ Node in a _GraphViewer, with pointers to other nodes it depends on (needs data/input from)
+    and nodes it provides input to.
+
+    This structure is designed to be as immutable as possible (as far as Python allows)
+ to be (re)used in iterative/recursive graph handling algorithms,
+ without having to worry about accidentally propagating changed state to other parts of the graph.
+ """
+
+ # Node ids of other nodes this node depends on (aka parents)
+ depends_on: frozenset[NodeId]
+ # Node ids of other nodes that depend on this node (aka children)
+ flows_to: frozenset[NodeId]
+
+ # Backend ids this node is marked to be supported on
+ # value None means it is unknown/unconstrained for this node
+ backend_candidates: Union[frozenset[BackendId], None]
+
+ def __init__(
+ self,
+ *,
+ depends_on: Union[Iterable[NodeId], NodeId, None] = None,
+ flows_to: Union[Iterable[NodeId], NodeId, None] = None,
+ backend_candidates: Union[Iterable[BackendId], BackendId, None] = None,
+ ):
+        # TODO: type coercion in __init__ of frozen dataclasses is a bit ugly. Use attrs with field converters instead?
+ super().__init__()
+ object.__setattr__(self, "depends_on", to_frozenset(depends_on or []))
+ object.__setattr__(self, "flows_to", to_frozenset(flows_to or []))
+ backend_candidates = to_frozenset(backend_candidates) if backend_candidates is not None else None
+ object.__setattr__(self, "backend_candidates", backend_candidates)
+
+ def __repr__(self):
+ # Somewhat cryptic, but compact representation of node attributes
+ depends_on = (" <" + ",".join(sorted(self.depends_on))) if self.depends_on else ""
+ flows_to = (" >" + ",".join(sorted(self.flows_to))) if self.flows_to else ""
+ backends = (" @" + ",".join(sorted(self.backend_candidates))) if self.backend_candidates else ""
+ return f"[{type(self).__name__}{depends_on}{flows_to}{backends}]"
+
+
+class _GraphViewer:
+ """
+    Internal utility to have a read-only view on the topological structure of a process graph
+ and track the flow of backend support.
+
+ """
+
+ def __init__(self, node_map: dict[NodeId, _GVNode]):
+ self._check_consistency(node_map=node_map)
+ # Work with a read-only proxy to prevent accidental changes
+ self._graph: Mapping[NodeId, _GVNode] = types.MappingProxyType(node_map.copy())
+
+ @staticmethod
+ def _check_consistency(node_map: dict[NodeId, _GVNode]):
+ """Check (link) consistency of given node map"""
+ key_ids = set(node_map.keys())
+ linked_ids = set(k for n in node_map.values() for k in n.depends_on.union(n.flows_to))
+ unknown = linked_ids.difference(key_ids)
+ if unknown:
+ raise GraphSplitException(f"Inconsistent node map: {key_ids=} != {linked_ids=}: {unknown=}")
+ bad_links = set()
+ for node_id, node in node_map.items():
+ bad_links.update((other, node_id) for other in node.depends_on if node_id not in node_map[other].flows_to)
+ bad_links.update((node_id, other) for other in node.flows_to if node_id not in node_map[other].depends_on)
+ if bad_links:
+ raise GraphSplitException(f"Inconsistent node map: {bad_links=}")
+
+ def __repr__(self):
+ return f"<{type(self).__name__}({self._graph})>"
+
+ @classmethod
+ def from_flat_graph(cls, flat_graph: FlatPG, supporting_backends: SupportingBackendsMapper = (lambda n, d: None)):
+ """
+ Build _GraphViewer from a flat process graph representation
+ """
+ _log.debug(f"_GraphViewer.from_flat_graph: {flat_graph.keys()=}")
+ # Extract dependency links between nodes
+ depends_on = collections.defaultdict(list)
+ flows_to = collections.defaultdict(list)
+ for node_id, node in flat_graph.items():
+ for arg_value in node.get("arguments", {}).values():
+ if isinstance(arg_value, dict) and list(arg_value.keys()) == ["from_node"]:
+ from_node = arg_value["from_node"]
+ depends_on[node_id].append(from_node)
+ flows_to[from_node].append(node_id)
+ graph = {
+ node_id: _GVNode(
+ depends_on=depends_on.get(node_id, []),
+ flows_to=flows_to.get(node_id, []),
+ backend_candidates=supporting_backends(node_id, node),
+ )
+ for node_id, node in flat_graph.items()
+ }
+ return cls(node_map=graph)
+
+ @classmethod
+ def from_edges(
+ cls,
+ edges: Iterable[Tuple[NodeId, NodeId]],
+ supporting_backends_mapper: SupportingBackendsMapper = (lambda n, d: None),
+ ):
+ """
+        Simple factory to build a graph from parent-child tuples, for testing purposes
+ """
+ depends_on = collections.defaultdict(list)
+ flows_to = collections.defaultdict(list)
+ for parent, child in edges:
+ depends_on[child].append(parent)
+ flows_to[parent].append(child)
+
+ graph = {
+ node_id: _GVNode(
+ depends_on=depends_on.get(node_id, []),
+ flows_to=flows_to.get(node_id, []),
+ backend_candidates=supporting_backends_mapper(node_id, {}),
+ )
+ for node_id in set(depends_on.keys()).union(flows_to.keys())
+ }
+ return cls(node_map=graph)
+
+ def node(self, node_id: NodeId) -> _GVNode:
+ if node_id not in self._graph:
+ raise GraphSplitException(f"Invalid node id {node_id!r}.")
+ return self._graph[node_id]
+
+ def iter_nodes(self) -> Iterator[Tuple[NodeId, _GVNode]]:
+ """Iterate through node_id-node pairs"""
+ yield from self._graph.items()
+
+ def _walk(
+ self,
+ seeds: Iterable[NodeId],
+ next_nodes: Callable[[NodeId], Iterable[NodeId]],
+ include_seeds: bool = True,
+ auto_sort: bool = True,
+ ) -> Iterator[NodeId]:
+ """
+ Walk the graph nodes starting from given seed nodes,
+        taking steps as defined by the `next_nodes` function.
+ Walks breadth first and each node is only visited once.
+
+ :param include_seeds: whether to include the seed nodes in the walk
+ :param auto_sort: visit "next" nodes of a given node lexicographically sorted
+ to make the walk deterministic.
+ """
+ # TODO: option to walk depth first instead of breadth first?
+ if auto_sort:
+ # Automatically sort next nodes to make walk more deterministic
+ prepare = sorted
+ else:
+ prepare = lambda x: x
+
+ if include_seeds:
+ visited = set()
+ to_visit = list(prepare(seeds))
+ else:
+ visited = set(seeds)
+ to_visit = [n for s in seeds for n in prepare(next_nodes(s))]
+
+ while to_visit:
+ node_id = to_visit.pop(0)
+ if node_id in visited:
+ continue
+ yield node_id
+ visited.add(node_id)
+ to_visit.extend(prepare(set(next_nodes(node_id)).difference(visited)))
+
+ def walk_upstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]:
+ """
+ Walk upstream nodes (along `depends_on` link) starting from given seed nodes.
+        The seed nodes themselves are optionally included; the walk is breadth first.
+ """
+ return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).depends_on, include_seeds=include_seeds)
+
+ def walk_downstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]:
+ """
+ Walk downstream nodes (along `flows_to` link) starting from given seed nodes.
+        The seed nodes themselves are optionally included; the walk is breadth first.
+ """
+ return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).flows_to, include_seeds=include_seeds)
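
A small sketch of the walk order on a three-node chain, built with the `from_edges` test factory above:

```python
viewer = _GraphViewer.from_edges([("a", "b"), ("b", "c")])
# Upstream walk from the sink: seeds first, then parents, breadth first.
print(list(viewer.walk_upstream_nodes(seeds=["c"])))  # ['c', 'b', 'a']
# Downstream walk from the source, excluding the seed itself.
print(list(viewer.walk_downstream_nodes(seeds=["a"], include_seeds=False)))  # ['b', 'c']
```
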
+
+ def get_backend_candidates_for_node(self, node_id: NodeId) -> Union[frozenset[BackendId], None]:
+ """Determine backend candidates for given node id"""
+ # TODO: cache intermediate sets? (Only when caching is safe: e.g. wrapped graph is immutable/not manipulated)
+ if self.node(node_id).backend_candidates is not None:
+ # Node has explicit backend candidates listed
+ return self.node(node_id).backend_candidates
+ elif self.node(node_id).depends_on:
+ # Backend support is unset: determine it (as intersection) from upstream nodes
+ return self.get_backend_candidates_for_node_set(self.node(node_id).depends_on)
+ else:
+ return None
+
+ def get_backend_candidates_for_node_set(self, node_ids: Iterable[NodeId]) -> Union[frozenset[BackendId], None]:
+ """
+ Determine backend candidates for a set of nodes
+ """
+ candidates = set(self.get_backend_candidates_for_node(n) for n in node_ids)
+ if candidates == {None}:
+ return None
+ candidates.discard(None)
+ return functools.reduce(lambda a, b: a.intersection(b), candidates)
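
To illustrate the intersection semantics (backend ids are made up): a node without explicit candidates inherits the intersection of its upstream nodes' candidates.

```python
viewer = _GraphViewer.from_edges(
    [("lc1", "merge"), ("lc2", "merge")],
    supporting_backends_mapper=lambda node_id, node: {
        "lc1": ["b1", "b2"],
        "lc2": ["b2", "b3"],
    }.get(node_id),
)
# "merge" has no explicit candidates, so they are derived from its parents.
print(viewer.get_backend_candidates_for_node("merge"))  # frozenset({'b2'})
```
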
+
+ def find_forsaken_nodes(self) -> Set[NodeId]:
+ """
+ Find nodes that have no backend candidates to process them
+ """
+ return set(
+ node_id for (node_id, _) in self.iter_nodes() if self.get_backend_candidates_for_node(node_id) == set()
+ )
+
+ def find_articulation_points(self) -> Set[NodeId]:
+ """
+ Find articulation points (cut vertices) in the directed graph:
+ nodes that when removed would split the graph into multiple sub-graphs.
+
+        Note that, unlike in traditional graph theory, the search also includes leaf nodes
+        (e.g. nodes with no parents): in this context of openEO graph splitting,
+        "cutting" a node means replacing it with two disconnected new nodes
+        (one connected to the original parents and one connected to the original children).
+ """
+ # Approach: label the start nodes (e.g. load_collection) with their id and weight 1.
+ # Propagate these labels along the depends-on links, but split/sum the weight according
+ # to the number of children/parents.
+ # At the end: the articulation points are the nodes where all flows have weight 1.
+
+ # Mapping: node_id -> start_node_id -> flow_weight
+ flow_weights: Dict[NodeId, Dict[NodeId, fractions.Fraction]] = {}
+
+ # Initialize at the pure input nodes (nodes with no upstream dependencies)
+ for node_id, node in self.iter_nodes():
+ if not node.depends_on:
+ flow_weights[node_id] = {node_id: fractions.Fraction(1, 1)}
+
+ # Propagate flow weights using recursion + caching
+ def get_flow_weights(node_id: NodeId) -> Dict[NodeId, fractions.Fraction]:
+ nonlocal flow_weights
+ if node_id not in flow_weights:
+ flow_weights[node_id] = {}
+ # Calculate from upstream nodes
+ for upstream in self.node(node_id).depends_on:
+ for start_node_id, weight in get_flow_weights(upstream).items():
+ flow_weights[node_id].setdefault(start_node_id, fractions.Fraction(0, 1))
+ flow_weights[node_id][start_node_id] += weight / len(self.node(upstream).flows_to)
+ return flow_weights[node_id]
+
+ for node_id, node in self.iter_nodes():
+ get_flow_weights(node_id)
+
+ # Select articulation points: nodes where all flows have weight 1
+ return set(node_id for node_id, flows in flow_weights.items() if all(w == 1 for w in flows.values()))
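
A worked example of the flow-weight reasoning, on a diamond-shaped graph: the source and the merge node carry full weight 1 (cut points), while each middle branch only carries weight 1/2:

```python
viewer = _GraphViewer.from_edges(
    [("lc", "bands"), ("lc", "temporal"), ("bands", "merge"), ("temporal", "merge")]
)
# Splitting at "bands" or "temporal" would not disconnect the graph,
# so only the source and the merge node qualify.
print(viewer.find_articulation_points())  # {'lc', 'merge'}
```
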
+
+ def split_at(self, split_node_id: NodeId) -> Tuple[_GraphViewer, _GraphViewer]:
+ """
+        Split the graph at the given node id (must be an articulation point),
+        creating two new graph viewers that contain the original nodes and an adapted copy of the split node.
+
+ :return: two _GraphViewer objects: the upstream subgraph and the downstream subgraph
+ """
+ split_node = self.node(split_node_id)
+
+ # Walk the graph, upstream from the split node
+ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
+ node = self.node(node_id)
+ if node_id == split_node_id:
+ return node.depends_on
+ else:
+ return node.depends_on.union(node.flows_to)
+
+ up_node_ids = set(self._walk(seeds=[split_node_id], next_nodes=next_nodes))
+
+ if split_node.flows_to.intersection(up_node_ids):
+ raise GraphSplitException(f"Graph can not be split at {split_node_id!r}: not an articulation point.")
+
+ up_graph = {n: self.node(n) for n in up_node_ids}
+ # Replacement of original split node: no `flows_to` links
+ up_graph[split_node_id] = _GVNode(
+ depends_on=split_node.depends_on,
+ backend_candidates=split_node.backend_candidates,
+ )
+ up = _GraphViewer(node_map=up_graph)
+
+ down_graph = {n: node for n, node in self.iter_nodes() if n not in up_node_ids}
+ # Replacement of original split node: no `depends_on` links
+        # and perhaps more importantly: do not copy over the original `backend_candidates`
+ down_graph[split_node_id] = _GVNode(
+ flows_to=split_node.flows_to,
+ backend_candidates=None,
+ )
+ down = _GraphViewer(node_map=down_graph)
+
+ return up, down
+
+ def produce_split_locations(
+ self,
+ limit: int = 20,
+ allow_split: Callable[[NodeId], bool] = lambda n: True,
+ ) -> Iterator[List[NodeId]]:
+ """
+        Produce split locations: listings of node ids at which the graph can be cut into disjoint subgraphs that can be processed independently.
+
+ :param limit: maximum number of split locations to produce
+ :param allow_split: predicate to determine if a node can be split on (e.g. to deny splitting on certain nodes)
+
+        :return: iterator of node listings.
+            Each node listing encodes a graph split (node ids where to split).
+            A node listing is ordered with the following in mind:
+            - the first node id makes a first split into an upstream and a downstream part.
+              The upstream part can be handled by a single backend.
+              The downstream part is not necessarily covered by a single backend,
+              in which case one or more additional splits will be necessary.
+            - the second node id makes a second split of the downstream part of
+              the previous split.
+            - etc.
+ """
+ # TODO: allow_split: possible to make this backend-dependent in some way?
+
+ # Find nodes that have empty set of backend_candidates
+ forsaken_nodes = self.find_forsaken_nodes()
+
+ if forsaken_nodes:
+ # Sort forsaken nodes (based on forsaken parent count), to start higher up the graph
+ forsaken_nodes = sorted(
+ forsaken_nodes, key=lambda n: (sum(p in forsaken_nodes for p in self.node(n).depends_on), n)
+ )
+ _log.debug(f"_GraphViewer.produce_split_locations: {forsaken_nodes=}")
+
+ # Collect nodes where we could split the graph in disjoint subgraphs
+ articulation_points: Set[NodeId] = set(self.find_articulation_points())
+ _log.debug(f"_GraphViewer.produce_split_locations: {articulation_points=}")
+
+ # Walk upstream from forsaken nodes to find articulation points, where we can cut
+ split_options = [
+ n
+ for n in self.walk_upstream_nodes(seeds=forsaken_nodes, include_seeds=False)
+ if n in articulation_points and allow_split(n)
+ ]
+ _log.debug(f"_GraphViewer.produce_split_locations: {split_options=}")
+ if not split_options:
+ raise GraphSplitException("No split options found.")
+ # TODO: Do we really need a limit? Or is there a practical scalability risk to list all possibilities?
+ assert limit > 0
+ for split_node_id in split_options[:limit]:
+ _log.debug(f"_GraphViewer.produce_split_locations: splitting at {split_node_id=}")
+ up, down = self.split_at(split_node_id)
+ # The upstream part should now be handled by a single backend
+ assert not up.find_forsaken_nodes()
+ # Recursively split downstream part if necessary
+ if down.find_forsaken_nodes():
+ down_splits = list(down.produce_split_locations(limit=max(limit - 1, 1), allow_split=allow_split))
+ else:
+ down_splits = [[]]
+
+ for down_split in down_splits:
+ yield [split_node_id] + down_split
+
+ else:
+ # All nodes can be handled as is, no need to split
+ yield []
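
A sketch of the split-location search on a graph whose two branches require different back-ends, so the merge node ends up "forsaken" (backend ids are illustrative):

```python
viewer = _GraphViewer.from_edges(
    [("lc1", "merge"), ("lc2", "merge")],
    supporting_backends_mapper=lambda node_id, node: {"lc1": ["b1"], "lc2": ["b2"]}.get(node_id),
)
# "merge" has an empty candidate intersection, so a split is needed.
# Each yielded listing is one way to cut the graph: here, either branch
# can be split off on its own.
print(list(viewer.produce_split_locations()))  # [['lc1'], ['lc2']]
```
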
+
+ def split_at_multiple(self, split_nodes: List[NodeId]) -> Dict[Union[NodeId, None], _GraphViewer]:
+ """
+        Split the graph viewer at multiple nodes, in the order provided.
+        Each split produces an upstream and a downstream graph;
+        the downstream graph is used for the next split,
+        so the split nodes should be ordered accordingly.
+
+        Returns a dictionary with:
+        - key: split node id, or None for the final downstream graph
+        - value: the corresponding subgraph viewer
+ """
+ result = {}
+ graph_to_split = self
+ for split_node_id in split_nodes:
+ up, down = graph_to_split.split_at(split_node_id=split_node_id)
+ result[split_node_id] = up
+ graph_to_split = down
+ result[None] = graph_to_split
+ return result
+
+
+class DeepGraphSplitter(ProcessGraphSplitterInterface):
+ """
+ More advanced graph splitting (compared to just splitting off `load_collection` nodes)
+
+ :param split_deny_list: list of process ids that should not be split on
+ """
+
+ def __init__(
+ self,
+ supporting_backends: SupportingBackendsMapper,
+ primary_backend: Optional[BackendId] = None,
+ split_deny_list: Iterable[ProcessId] = (),
+ ):
+ self._supporting_backends_mapper = supporting_backends
+ self._primary_backend = primary_backend
+ # TODO also support other deny mechanisms, e.g. callable instead of a deny list?
+ self._split_deny_list = set(split_deny_list)
+
+ def _pick_backend(self, backend_candidates: Union[frozenset[BackendId], None]) -> BackendId:
+ if backend_candidates is None:
+ if self._primary_backend:
+ return self._primary_backend
+ else:
+ raise GraphSplitException("DeepGraphSplitter._pick_backend: No backend candidates.")
+ else:
+ # TODO: better backend selection mechanism
+ return sorted(backend_candidates)[0]
+
+ def split(self, process_graph: FlatPG) -> _PGSplitResult:
+ graph = _GraphViewer.from_flat_graph(
+ flat_graph=process_graph, supporting_backends=self._supporting_backends_mapper
+ )
+
+ def allow_split(node_id: NodeId) -> bool:
+ process_id = deep_get(process_graph, node_id, "process_id", default=None)
+ return process_id not in self._split_deny_list
+
+ for split_nodes in graph.produce_split_locations(allow_split=allow_split):
+ _log.debug(f"DeepGraphSplitter.split: evaluating split nodes: {split_nodes=}")
+
+ split_views = graph.split_at_multiple(split_nodes=split_nodes)
+
+ # Extract nodes and backend ids for each subgraph
+ subgraph_node_ids = {k: set(n for n, _ in v.iter_nodes()) for k, v in split_views.items()}
+ subgraph_backend_ids = {
+ k: self._pick_backend(backend_candidates=v.get_backend_candidates_for_node_set(subgraph_node_ids[k]))
+ for k, v in split_views.items()
+ }
+ _log.debug(f"DeepGraphSplitter.split: {subgraph_node_ids=} {subgraph_backend_ids=}")
+
+ # Handle primary graph
+ split_views.pop(None)
+ primary_node_ids = subgraph_node_ids.pop(None)
+ primary_backend_id = subgraph_backend_ids.pop(None)
+
+ # Handle secondary graphs
+ secondary_graphs = [
+ _PGSplitSubGraph(split_node=k, node_ids=subgraph_node_ids[k], backend_id=subgraph_backend_ids[k])
+ for k in split_views.keys()
+ ]
+
+ if self._primary_backend is None or primary_backend_id == self._primary_backend:
+ _log.debug(f"DeepGraphSplitter.split: current split matches constraints")
+ return _PGSplitResult(
+ primary_node_ids=primary_node_ids,
+ primary_backend_id=primary_backend_id,
+ secondary_graphs=secondary_graphs,
+ )
+
+ raise GraphSplitException("DeepGraphSplitter.split: No matching split found.")
diff --git a/src/openeo_aggregator/partitionedjobs/tracking.py b/src/openeo_aggregator/partitionedjobs/tracking.py
index a34f36e..26ce77a 100644
--- a/src/openeo_aggregator/partitionedjobs/tracking.py
+++ b/src/openeo_aggregator/partitionedjobs/tracking.py
@@ -24,7 +24,7 @@
SubJob,
)
from openeo_aggregator.partitionedjobs.crossbackend import (
- CrossBackendSplitter,
+ CrossBackendJobSplitter,
SubGraphId,
)
from openeo_aggregator.partitionedjobs.splitting import TileGridSplitter
@@ -71,7 +71,7 @@ def create_crossbackend_pjob(
process: PGWithMetadata,
metadata: dict,
job_options: Optional[dict] = None,
- splitter: CrossBackendSplitter,
+ splitter: CrossBackendJobSplitter,
) -> str:
"""
crossbackend partitioned job creation is different from original partitioned
diff --git a/src/openeo_aggregator/testing.py b/src/openeo_aggregator/testing.py
index 6cf60d8..9f044b8 100644
--- a/src/openeo_aggregator/testing.py
+++ b/src/openeo_aggregator/testing.py
@@ -66,13 +66,13 @@ def exists(self, path):
def get(self, path):
self._assert_open()
if path not in self.data:
- raise kazoo.exceptions.NoNodeError()
+ raise kazoo.exceptions.NoNodeError(path)
return self.data[path]
def get_children(self, path):
self._assert_open()
if path not in self.data:
- raise kazoo.exceptions.NoNodeError()
+ raise kazoo.exceptions.NoNodeError(path)
parent = path.split("/")
return [p.split("/")[-1] for p in self.data if p.split("/")[:-1] == parent]
@@ -103,6 +103,8 @@ def approx_now(abs=10):
class ApproxStr:
"""Pytest helper in style of `pytest.approx`, but for string checking, based on prefix, body and or suffix"""
+ # TODO: port to dirty_equals
+
def __init__(
self,
prefix: Optional[str] = None,
diff --git a/tests/partitionedjobs/test_api.py b/tests/partitionedjobs/test_api.py
index aaf1946..956cae6 100644
--- a/tests/partitionedjobs/test_api.py
+++ b/tests/partitionedjobs/test_api.py
@@ -52,7 +52,7 @@ def __init__(self, date: str):
def dummy1(backend1, requests_mock) -> DummyBackend:
# TODO: rename this fixture to dummy_backed1 for clarity
dummy = DummyBackend(requests_mock=requests_mock, backend_url=backend1, job_id_template="1-jb-{i}")
- dummy.setup_basic_requests_mocks()
+ dummy.setup_basic_requests_mocks(collections=["S1", "S2"])
dummy.register_user(bearer_token=TEST_USER_BEARER_TOKEN, user_id=TEST_USER)
return dummy
@@ -61,7 +61,7 @@ def dummy1(backend1, requests_mock) -> DummyBackend:
def dummy2(backend2, requests_mock) -> DummyBackend:
# TODO: rename this fixture to dummy_backed2 for clarity
dummy = DummyBackend(requests_mock=requests_mock, backend_url=backend2, job_id_template="2-jb-{i}")
- dummy.setup_basic_requests_mocks(collections=["S22"])
+ dummy.setup_basic_requests_mocks(collections=["T11", "T22"])
dummy.register_user(bearer_token=TEST_USER_BEARER_TOKEN, user_id=TEST_USER)
return dummy
@@ -685,17 +685,28 @@ def _partitioned_job_tracking(self, zk_client):
yield
@now.mock
- def test_create_job_simple(self, flask_app, api100, zk_db, dummy1):
+ @pytest.mark.parametrize(
+ "split_strategy",
+ [
+ "crossbackend",
+ {"crossbackend": {"method": "simple"}},
+ {"crossbackend": {"method": "deep"}},
+ ],
+ )
+ def test_create_job_simple(self, flask_app, api100, zk_db, dummy1, split_strategy):
"""Handling of single "load_collection" process graph"""
api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
- pg = {"lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}, "result": True}}
+ pg = {
+ # lc1 (that's it, that's the graph)
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}, "result": True}
+ }
res = api100.post(
"/jobs",
json={
"process": {"process_graph": pg},
- "job_options": {"split_strategy": "crossbackend"},
+ "job_options": {"split_strategy": split_strategy},
},
).assert_status_code(201)
@@ -719,7 +730,7 @@ def test_create_job_simple(self, flask_app, api100, zk_db, dummy1):
"created": self.now.epoch,
"process": {"process_graph": pg},
"metadata": {"log_level": "info"},
- "job_options": {"split_strategy": "crossbackend"},
+ "job_options": {"split_strategy": split_strategy},
"result_jobs": ["main"],
}
@@ -753,12 +764,23 @@ def test_create_job_simple(self, flask_app, api100, zk_db, dummy1):
assert pg == {"lc1": {"arguments": {"id": "S2"}, "process_id": "load_collection", "result": True}}
@now.mock
- def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock):
+ @pytest.mark.parametrize(
+ "split_strategy",
+ [
+ "crossbackend",
+ {"crossbackend": {"method": "simple"}},
+ {"crossbackend": {"method": "deep", "primary_backend": "b1"}},
+ ],
+ )
+ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, dummy2, requests_mock, split_strategy):
api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
pg = {
+ # lc1 lc2
+ # \ /
+ # merge
"lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
- "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}},
"merge": {
"process_id": "merge_cubes",
"arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
@@ -767,13 +789,16 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
}
requests_mock.get(
- "https://b1.test/v1/jobs/1-jb-0/results?partial=true",
- json={"links": [{"rel": "canonical", "href": "https://data.b1.test/123abc"}]},
+ "https://b2.test/v1/jobs/2-jb-0/results?partial=true",
+ json={"links": [{"rel": "canonical", "href": "https://data.b2.test/123abc"}]},
)
res = api100.post(
"/jobs",
- json={"process": {"process_graph": pg}, "job_options": {"split_strategy": "crossbackend"}},
+ json={
+ "process": {"process_graph": pg},
+ "job_options": {"split_strategy": split_strategy},
+ },
).assert_status_code(201)
pjob_id = "pj-20220119-123456"
@@ -796,7 +821,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
"created": self.now.epoch,
"process": {"process_graph": pg},
"metadata": {"log_level": "info"},
- "job_options": {"split_strategy": "crossbackend"},
+ "job_options": {"split_strategy": split_strategy},
"result_jobs": ["main"],
}
@@ -810,17 +835,17 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
# Inspect stored subjob metadata
subjobs = zk_db.list_subjobs(user_id=TEST_USER, pjob_id=pjob_id)
assert subjobs == {
- "b1:lc2": {
- "backend_id": "b1",
+ "b2:lc2": {
+ "backend_id": "b2",
"process_graph": {
- "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
- "sr1": {
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}},
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
},
},
- "title": "Partitioned job pjob_id='pj-20220119-123456' sjob_id='b1:lc2'",
+ "title": "Partitioned job pjob_id='pj-20220119-123456' sjob_id='b2:lc2'",
},
"main": {
"backend_id": "b1",
@@ -828,7 +853,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
"lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
"lc2": {
"process_id": "load_stac",
- "arguments": {"url": "https://data.b1.test/123abc"},
+ "arguments": {"url": "https://data.b2.test/123abc"},
},
"merge": {
"process_id": "merge_cubes",
@@ -841,7 +866,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
}
sjob_id = "main"
- expected_job_id = "1-jb-1"
+ expected_job_id = "1-jb-0"
assert zk_db.get_sjob_status(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == {
"status": "created",
"timestamp": self.now.epoch,
@@ -853,7 +878,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
"lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
"lc2": {
"process_id": "load_stac",
- "arguments": {"url": "https://data.b1.test/123abc"},
+ "arguments": {"url": "https://data.b2.test/123abc"},
},
"merge": {
"process_id": "merge_cubes",
@@ -862,18 +887,18 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
},
}
- sjob_id = "b1:lc2"
- expected_job_id = "1-jb-0"
+ sjob_id = "b2:lc2"
+ expected_job_id = "2-jb-0"
assert zk_db.get_sjob_status(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == {
"status": "created",
"timestamp": self.now.epoch,
"message": None,
}
assert zk_db.get_backend_job_id(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == expected_job_id
- assert dummy1.get_job_status(TEST_USER, expected_job_id) == "created"
- assert dummy1.get_job_data(TEST_USER, expected_job_id).create["process"]["process_graph"] == {
- "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
- "sr1": {
+ assert dummy2.get_job_status(TEST_USER, expected_job_id) == "created"
+ assert dummy2.get_job_data(TEST_USER, expected_job_id).create["process"]["process_graph"] == {
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}},
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
@@ -881,13 +906,24 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock)
}
@now.mock
- def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_mock):
+ @pytest.mark.parametrize(
+ "split_strategy",
+ [
+ "crossbackend",
+ {"crossbackend": {"method": "simple"}},
+ {"crossbackend": {"method": "deep", "primary_backend": "b1"}},
+ ],
+ )
+ def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, dummy2, requests_mock, split_strategy):
"""Run the jobs and get results"""
api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
pg = {
+ # lc1 lc2
+ # \ /
+ # merge
"lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
- "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}},
"merge": {
"process_id": "merge_cubes",
"arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
@@ -896,15 +932,15 @@ def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_
}
requests_mock.get(
- "https://b1.test/v1/jobs/1-jb-0/results?partial=true",
- json={"links": [{"rel": "canonical", "href": "https://data.b1.test/123abc"}]},
+ "https://b2.test/v1/jobs/2-jb-0/results?partial=true",
+ json={"links": [{"rel": "canonical", "href": "https://data.b2.test/123abc"}]},
)
res = api100.post(
"/jobs",
json={
"process": {"process_graph": pg},
- "job_options": {"split_strategy": "crossbackend"},
+ "job_options": {"split_strategy": split_strategy},
},
).assert_status_code(201)
@@ -923,21 +959,21 @@ def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_
# start job
api100.post(f"/jobs/{expected_job_id}/results").assert_status_code(202)
- dummy1.set_job_status(TEST_USER, "1-jb-0", status="running")
- dummy1.set_job_status(TEST_USER, "1-jb-1", status="queued")
+ dummy2.set_job_status(TEST_USER, "2-jb-0", status="running")
+ dummy1.set_job_status(TEST_USER, "1-jb-0", status="queued")
res = api100.get(f"/jobs/{expected_job_id}").assert_status_code(200)
assert res.json == DictSubSet({"id": expected_job_id, "status": "running", "progress": 0})
# First job is ready
- dummy1.set_job_status(TEST_USER, "1-jb-0", status="finished")
- dummy1.setup_assets(job_id=f"1-jb-0", assets=["1-jb-0-result.tif"])
- dummy1.set_job_status(TEST_USER, "1-jb-1", status="running")
+ dummy2.set_job_status(TEST_USER, "2-jb-0", status="finished")
+ dummy2.setup_assets(job_id=f"2-jb-0", assets=["2-jb-0-result.tif"])
+ dummy1.set_job_status(TEST_USER, "1-jb-0", status="running")
res = api100.get(f"/jobs/{expected_job_id}").assert_status_code(200)
assert res.json == DictSubSet({"id": expected_job_id, "status": "running", "progress": 50})
# Main job is ready too
- dummy1.set_job_status(TEST_USER, "1-jb-1", status="finished")
- dummy1.setup_assets(job_id=f"1-jb-1", assets=["1-jb-1-result.tif"])
+ dummy1.set_job_status(TEST_USER, "1-jb-0", status="finished")
+ dummy1.setup_assets(job_id=f"1-jb-0", assets=["1-jb-0-result.tif"])
res = api100.get(f"/jobs/{expected_job_id}").assert_status_code(200)
assert res.json == DictSubSet({"id": expected_job_id, "status": "finished", "progress": 100})
@@ -947,10 +983,10 @@ def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_
{
"id": expected_job_id,
"assets": {
- "main-1-jb-1-result.tif": {
- "href": "https://b1.test/v1/jobs/1-jb-1/results/1-jb-1-result.tif",
+ "main-1-jb-0-result.tif": {
+ "href": "https://b1.test/v1/jobs/1-jb-0/results/1-jb-0-result.tif",
"roles": ["data"],
- "title": "main-1-jb-1-result.tif",
+ "title": "main-1-jb-0-result.tif",
"type": "application/octet-stream",
},
},
@@ -958,14 +994,25 @@ def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_
)
@now.mock
- def test_failing_create(self, flask_app, api100, zk_db, dummy1):
+ @pytest.mark.parametrize(
+ "split_strategy",
+ [
+ "crossbackend",
+ {"crossbackend": {"method": "simple"}},
+ {"crossbackend": {"method": "deep", "primary_backend": "b1"}},
+ ],
+ )
+ def test_failing_create(self, flask_app, api100, zk_db, dummy1, dummy2, split_strategy):
"""Run what happens when creation of sub batch job fails on upstream backend"""
api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
- dummy1.fail_create_job = True
+ dummy2.fail_create_job = True
pg = {
+ # lc1 lc2
+ # \ /
+ # merge
"lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
- "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}},
"merge": {
"process_id": "merge_cubes",
"arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
@@ -977,7 +1024,7 @@ def test_failing_create(self, flask_app, api100, zk_db, dummy1):
"/jobs",
json={
"process": {"process_graph": pg},
- "job_options": {"split_strategy": "crossbackend"},
+ "job_options": {"split_strategy": split_strategy},
},
).assert_status_code(201)
@@ -993,3 +1040,138 @@ def test_failing_create(self, flask_app, api100, zk_db, dummy1):
"created": self.now.rfc3339,
"progress": 0,
}
+
+ @now.mock
+ def test_create_job_deep_basic(self, flask_app, api100, zk_db, dummy1, dummy2, requests_mock):
+ api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN)
+
+ pg = {
+ # lc1 lc2
+ # | |
+ # bands1 temporal2
+ # \ /
+ # merge
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}},
+ "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}}},
+ "temporal2": {"process_id": "filter_temporal", "arguments": {"data": {"from_node": "lc2"}}},
+ "merge": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}},
+ "result": True,
+ },
+ }
+
+ requests_mock.get(
+ "https://b2.test/v1/jobs/2-jb-0/results?partial=true",
+ json={"links": [{"rel": "canonical", "href": "https://data.b2.test/123abc"}]},
+ )
+
+ split_strategy = {"crossbackend": {"method": "deep", "primary_backend": "b1"}}
+ res = api100.post(
+ "/jobs",
+ json={
+ "process": {"process_graph": pg},
+ "job_options": {"split_strategy": split_strategy},
+ },
+ ).assert_status_code(201)
+
+ pjob_id = "pj-20220119-123456"
+ expected_job_id = f"agg-{pjob_id}"
+ assert res.headers["Location"] == f"http://oeoa.test/openeo/1.0.0/jobs/{expected_job_id}"
+ assert res.headers["OpenEO-Identifier"] == expected_job_id
+
+ res = api100.get(f"/jobs/{expected_job_id}").assert_status_code(200)
+ assert res.json == {
+ "id": expected_job_id,
+ "process": {"process_graph": pg},
+ "status": "created",
+ "created": self.now.rfc3339,
+ "progress": 0,
+ }
+
+ # Inspect stored parent job metadata
+ assert zk_db.get_pjob_metadata(user_id=TEST_USER, pjob_id=pjob_id) == {
+ "user_id": TEST_USER,
+ "created": self.now.epoch,
+ "process": {"process_graph": pg},
+ "metadata": {"log_level": "info"},
+ "job_options": {"split_strategy": split_strategy},
+ "result_jobs": ["main"],
+ }
+
+ assert zk_db.get_pjob_status(user_id=TEST_USER, pjob_id=pjob_id) == {
+ "status": "created",
+ "message": approx_str_contains("{'created': 2}"),
+ "timestamp": pytest.approx(self.now.epoch, abs=5),
+ "progress": 0,
+ }
+
+ # Inspect stored subjob metadata
+ subjobs = zk_db.list_subjobs(user_id=TEST_USER, pjob_id=pjob_id)
+ assert subjobs == {
+ "b2:temporal2": {
+ "backend_id": "b2",
+ "process_graph": {
+ "lc2": {"arguments": {"id": "T22"}, "process_id": "load_collection"},
+ "temporal2": {"arguments": {"data": {"from_node": "lc2"}}, "process_id": "filter_temporal"},
+ "_agg_crossbackend_save_result": {
+ "arguments": {"data": {"from_node": "temporal2"}, "format": "GTiff"},
+ "process_id": "save_result",
+ "result": True,
+ },
+ },
+ "title": "Partitioned job pjob_id='pj-20220119-123456' sjob_id='b2:temporal2'",
+ },
+ "main": {
+ "backend_id": "b1",
+ "process_graph": {
+ "lc1": {"arguments": {"id": "S2"}, "process_id": "load_collection"},
+ "bands1": {"arguments": {"data": {"from_node": "lc1"}}, "process_id": "filter_bands"},
+ "temporal2": {"arguments": {"url": "https://data.b2.test/123abc"}, "process_id": "load_stac"},
+ "merge": {
+ "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}},
+ "process_id": "merge_cubes",
+ "result": True,
+ },
+ },
+ "title": "Partitioned job pjob_id='pj-20220119-123456' " "sjob_id='main'",
+ },
+ }
+ sjob_id = "main"
+ expected_job_id = "1-jb-0"
+ assert zk_db.get_sjob_status(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == {
+ "status": "created",
+ "timestamp": self.now.epoch,
+ "message": None,
+ }
+ assert zk_db.get_backend_job_id(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == expected_job_id
+ assert dummy1.get_job_status(TEST_USER, expected_job_id) == "created"
+ assert dummy1.get_job_data(TEST_USER, expected_job_id).create["process"]["process_graph"] == {
+ "lc1": {"arguments": {"id": "S2"}, "process_id": "load_collection"},
+ "bands1": {"arguments": {"data": {"from_node": "lc1"}}, "process_id": "filter_bands"},
+ "temporal2": {"arguments": {"url": "https://data.b2.test/123abc"}, "process_id": "load_stac"},
+ "merge": {
+ "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}},
+ "process_id": "merge_cubes",
+ "result": True,
+ },
+ }
+ sjob_id = "b2:temporal2"
+ expected_job_id = "2-jb-0"
+ assert zk_db.get_sjob_status(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == {
+ "status": "created",
+ "timestamp": self.now.epoch,
+ "message": None,
+ }
+ assert zk_db.get_backend_job_id(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == expected_job_id
+ assert dummy2.get_job_status(TEST_USER, expected_job_id) == "created"
+ assert dummy2.get_job_data(TEST_USER, expected_job_id).create["process"]["process_graph"] == {
+ "lc2": {"arguments": {"id": "T22"}, "process_id": "load_collection"},
+ "temporal2": {"arguments": {"data": {"from_node": "lc2"}}, "process_id": "filter_temporal"},
+ "_agg_crossbackend_save_result": {
+ "arguments": {"data": {"from_node": "temporal2"}, "format": "GTiff"},
+ "process_id": "save_result",
+ "result": True,
+ },
+ }
diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py
index b4b6239..adf0e65 100644
--- a/tests/partitionedjobs/test_crossbackend.py
+++ b/tests/partitionedjobs/test_crossbackend.py
@@ -13,8 +13,16 @@
from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
from openeo_aggregator.partitionedjobs.crossbackend import (
- CrossBackendSplitter,
+ CrossBackendJobSplitter,
+ DeepGraphSplitter,
+ GraphSplitException,
+ LoadCollectionGraphSplitter,
SubGraphId,
+ SupportingBackendsMapper,
+ _GraphViewer,
+ _GVNode,
+ _PGSplitResult,
+ _PGSplitSubGraph,
run_partitioned_job,
)
@@ -22,7 +30,9 @@
class TestCrossBackendSplitter:
def test_split_simple(self):
process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
- splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo")
+ splitter = CrossBackendJobSplitter(
+ graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo")
+ )
res = splitter.split({"process_graph": process_graph})
assert res.subjobs == {"main": SubJob(process_graph, backend_id=None)}
@@ -30,7 +40,9 @@ def test_split_simple(self):
def test_split_streaming_simple(self):
process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
- splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo")
+ splitter = CrossBackendJobSplitter(
+ graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo")
+ )
res = splitter.split_streaming(process_graph)
assert isinstance(res, types.GeneratorType)
assert list(res) == [("main", SubJob(process_graph, backend_id=None), [])]
@@ -46,13 +58,15 @@ def test_split_basic(self):
"cube2": {"from_node": "lc2"},
},
},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
},
}
- splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+ splitter = CrossBackendJobSplitter(
+ graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+ )
res = splitter.split({"process_graph": process_graph})
assert res.subjobs == {
@@ -73,7 +87,7 @@ def test_split_basic(self):
"cube2": {"from_node": "lc2"},
},
},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
@@ -87,7 +101,7 @@ def test_split_basic(self):
"process_id": "load_collection",
"arguments": {"id": "B2_FAPAR"},
},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
@@ -109,13 +123,15 @@ def test_split_streaming_basic(self):
"cube2": {"from_node": "lc2"},
},
},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
},
}
- splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+ splitter = CrossBackendJobSplitter(
+ graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+ )
result = splitter.split_streaming(process_graph)
assert isinstance(result, types.GeneratorType)
@@ -128,7 +144,7 @@ def test_split_streaming_basic(self):
"process_id": "load_collection",
"arguments": {"id": "B2_FAPAR"},
},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
@@ -148,7 +164,7 @@ def test_split_streaming_basic(self):
"process_id": "merge_cubes",
"arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
@@ -175,7 +191,9 @@ def test_split_streaming_get_replacement(self):
"result": True,
},
}
- splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+ splitter = CrossBackendJobSplitter(
+ graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+ )
batch_jobs = {}
@@ -200,7 +218,7 @@ def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
SubJob(
process_graph={
"lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
@@ -215,7 +233,7 @@ def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
SubJob(
process_graph={
"lc3": {"process_id": "load_collection", "arguments": {"id": "B3_SCL"}},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc3"}, "format": "GTiff"},
"result": True,
@@ -365,13 +383,15 @@ def test_basic(self, aggregator: _FakeAggregator):
"cube2": {"from_node": "lc2"},
},
},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
},
}
- splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+ splitter = CrossBackendJobSplitter(
+ graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+ )
pjob: PartitionedJob = splitter.split({"process_graph": process_graph})
connection = openeo.Connection(aggregator.url)
@@ -400,7 +420,7 @@ def test_basic(self, aggregator: _FakeAggregator):
"cube2": {"from_node": "lc2"},
},
},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
"result": True,
@@ -411,9 +431,792 @@ def test_basic(self, aggregator: _FakeAggregator):
"process_id": "load_collection",
"arguments": {"id": "B2_FAPAR"},
},
- "sr1": {
+ "_agg_crossbackend_save_result": {
"process_id": "save_result",
"arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
"result": True,
},
}
+
+
+class TestGVNode:
+ def test_defaults(self):
+ node = _GVNode()
+ assert isinstance(node.depends_on, frozenset)
+ assert node.depends_on == frozenset()
+ assert isinstance(node.flows_to, frozenset)
+ assert node.flows_to == frozenset()
+ assert node.backend_candidates is None
+
+ def test_basic(self):
+ node = _GVNode(depends_on=["a", "b"], flows_to=["c", "d"], backend_candidates=["X"])
+ assert isinstance(node.depends_on, frozenset)
+ assert node.depends_on == frozenset(["a", "b"])
+ assert isinstance(node.flows_to, frozenset)
+ assert node.flows_to == frozenset(["c", "d"])
+ assert isinstance(node.backend_candidates, frozenset)
+ assert node.backend_candidates == frozenset(["X"])
+
+ def test_single_strings(self):
+ node = _GVNode(depends_on="apple", flows_to="banana", backend_candidates="coconut")
+ assert isinstance(node.depends_on, frozenset)
+ assert node.depends_on == frozenset(["apple"])
+ assert isinstance(node.flows_to, frozenset)
+ assert node.flows_to == frozenset(["banana"])
+ assert isinstance(node.backend_candidates, frozenset)
+ assert node.backend_candidates == frozenset(["coconut"])
+
+ def test_eq(self):
+ assert _GVNode() == _GVNode()
+ assert _GVNode(
+ depends_on=["a", "b"],
+ flows_to=["c", "d"],
+ backend_candidates="X",
+ ) == _GVNode(
+ depends_on=("b", "a"),
+ flows_to={"d", "c"},
+ backend_candidates=["X"],
+ )
+
+ def test_repr(self):
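+ # Expected repr notation below: "<" lists depends_on, ">" lists flows_to, "@" lists backend_candidates.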
+ assert repr(_GVNode()) == "[_GVNode]"
+ assert repr(_GVNode(depends_on="a", flows_to="b")) == "[_GVNode <a >b]"
+ assert repr(_GVNode(depends_on=["a", "b"], flows_to=["foo", "bar"])) == "[_GVNode <a,b >bar,foo]"
+ assert repr(_GVNode(depends_on="a", flows_to="b", backend_candidates=["x", "yy"])) == "[_GVNode <a >b @x,yy]"
+
+
+def supporting_backends_from_node_id_dict(data: dict) -> SupportingBackendsMapper:
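+ """Helper to build a SupportingBackendsMapper from a plain {node_id: backend_ids} dict (unknown nodes map to None)."""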
+ return lambda node_id, node: data.get(node_id)
+
+
+class TestGraphViewer:
+ def test_empty(self):
+ graph = _GraphViewer(node_map={})
+ assert list(graph.iter_nodes()) == []
+
+ @pytest.mark.parametrize(
+ ["node_map", "expected_error"],
+ [
+ ({"a": _GVNode(flows_to="b")}, r"Inconsistent.*unknown=\{'b'\}"),
+ ({"b": _GVNode(depends_on="a")}, r"Inconsistent.*unknown=\{'a'\}"),
+ ({"a": _GVNode(flows_to="b"), "b": _GVNode()}, r"Inconsistent.*bad_links=\{\('a', 'b'\)\}"),
+ ({"b": _GVNode(depends_on="a"), "a": _GVNode()}, r"Inconsistent.*bad_links=\{\('a', 'b'\)\}"),
+ ],
+ )
+ def test_check_consistency(self, node_map, expected_error):
+ with pytest.raises(GraphSplitException, match=expected_error):
+ _ = _GraphViewer(node_map=node_map)
+
+ def test_immutability(self):
+ node_map = {"a": _GVNode(flows_to="b"), "b": _GVNode(depends_on="a")}
+ graph = _GraphViewer(node_map=node_map)
+ assert sorted(graph.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))]
+
+ # Adding a node to the original map should not affect the graph
+ node_map["c"] = _GVNode()
+ assert sorted(graph.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))]
+
+ # Trying to mess with internals shouldn't work either
+ with pytest.raises(Exception, match="does not support item assignment"):
+ graph._graph["c"] = _GVNode()
+
+ assert sorted(graph.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))]
+
+ def test_from_flat_graph_basic(self):
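+ # from_flat_graph derives the edges from "from_node" references and annotates nodes with backend candidates.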
+ flat = {
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
+ "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
+ }
+ graph = _GraphViewer.from_flat_graph(
+ flat, supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"]})
+ )
+ assert sorted(graph.iter_nodes()) == [
+ ("lc1", _GVNode(flows_to=["ndvi1"], backend_candidates="b1")),
+ ("ndvi1", _GVNode(depends_on=["lc1"])),
+ ]
+
+ # TODO: test from_flat_graph with more complex graphs
+
+ def test_from_edges(self):
+ graph = _GraphViewer.from_edges([("a", "c"), ("b", "d"), ("c", "e"), ("d", "e"), ("e", "f")])
+ assert sorted(graph.iter_nodes()) == [
+ ("a", _GVNode(flows_to=["c"])),
+ ("b", _GVNode(flows_to=["d"])),
+ ("c", _GVNode(depends_on=["a"], flows_to=["e"])),
+ ("d", _GVNode(depends_on=["b"], flows_to=["e"])),
+ ("e", _GVNode(depends_on=["c", "d"], flows_to=["f"])),
+ ("f", _GVNode(depends_on=["e"])),
+ ]
+
+ @pytest.mark.parametrize(
+ ["seed", "include_seeds", "expected"],
+ [
+ (["a"], True, ["a"]),
+ (["a"], False, []),
+ (["c"], True, ["c", "a"]),
+ (["c"], False, ["a"]),
+ (["a", "c"], True, ["a", "c"]),
+ (["a", "c"], False, []),
+ (["c", "a"], True, ["a", "c"]),
+ (["c", "a"], False, []),
+ (["e"], True, ["e", "c", "d", "a", "b"]),
+ (["e"], False, ["c", "d", "a", "b"]),
+ (["e", "d"], True, ["d", "e", "b", "c", "a"]),
+ (["e", "d"], False, ["c", "b", "a"]),
+ (["d", "e"], True, ["d", "e", "b", "c", "a"]),
+ (["d", "e"], False, ["b", "c", "a"]),
+ (["f", "c"], True, ["c", "f", "a", "e", "d", "b"]),
+ (["f", "c"], False, ["e", "a", "d", "b"]),
+ ],
+ )
+ def test_walk_upstream_nodes(self, seed, include_seeds, expected):
+ graph = _GraphViewer.from_edges(
+ # a b
+ # | |
+ # c d
+ # \ /
+ # e
+ # |
+ # f
+ [("a", "c"), ("b", "d"), ("c", "e"), ("d", "e"), ("e", "f")]
+ )
+ assert list(graph.walk_upstream_nodes(seed, include_seeds)) == expected
+
+ def test_get_backend_candidates_basic(self):
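+ # A node's backend candidates are the intersection of its own constraint (if any) and those of its upstream nodes.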
+ graph = _GraphViewer.from_edges(
+ # a
+ # |
+ # b c
+ # \ /
+ # d
+ [("a", "b"), ("b", "d"), ("c", "d")],
+ supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": ["b1"], "c": ["b2"]}),
+ )
+ assert graph.get_backend_candidates_for_node("a") == {"b1"}
+ assert graph.get_backend_candidates_for_node("b") == {"b1"}
+ assert graph.get_backend_candidates_for_node("c") == {"b2"}
+ assert graph.get_backend_candidates_for_node("d") == set()
+
+ assert graph.get_backend_candidates_for_node_set(["a"]) == {"b1"}
+ assert graph.get_backend_candidates_for_node_set(["b"]) == {"b1"}
+ assert graph.get_backend_candidates_for_node_set(["c"]) == {"b2"}
+ assert graph.get_backend_candidates_for_node_set(["d"]) == set()
+ assert graph.get_backend_candidates_for_node_set(["a", "b"]) == {"b1"}
+ assert graph.get_backend_candidates_for_node_set(["a", "b", "c"]) == set()
+ assert graph.get_backend_candidates_for_node_set(["a", "b", "d"]) == set()
+
+ def test_get_backend_candidates_none(self):
+ graph = _GraphViewer.from_edges(
+ # a
+ # |
+ # b c
+ # \ /
+ # d
+ [("a", "b"), ("b", "d"), ("c", "d")],
+ )
+ assert graph.get_backend_candidates_for_node("a") is None
+ assert graph.get_backend_candidates_for_node("b") is None
+ assert graph.get_backend_candidates_for_node("c") is None
+ assert graph.get_backend_candidates_for_node("d") is None
+
+ assert graph.get_backend_candidates_for_node_set(["a", "b"]) is None
+ assert graph.get_backend_candidates_for_node_set(["a", "b", "c"]) is None
+
+ def test_get_backend_candidates_intersection(self):
+ graph = _GraphViewer.from_edges(
+ # a   b   c
+ #  \ / \ /
+ #   d   e
+ #    \ /
+ #     f
+ [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f")],
+ supporting_backends_mapper=supporting_backends_from_node_id_dict(
+ {"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]}
+ ),
+ )
+ assert graph.get_backend_candidates_for_node("a") == {"b1", "b2"}
+ assert graph.get_backend_candidates_for_node("b") == {"b2", "b3"}
+ assert graph.get_backend_candidates_for_node("c") == {"b4"}
+ assert graph.get_backend_candidates_for_node("d") == {"b2"}
+ assert graph.get_backend_candidates_for_node("e") == set()
+ assert graph.get_backend_candidates_for_node("f") == set()
+
+ assert graph.get_backend_candidates_for_node_set(["a", "b"]) == {"b2"}
+ assert graph.get_backend_candidates_for_node_set(["a", "b", "d"]) == {"b2"}
+ assert graph.get_backend_candidates_for_node_set(["c", "d"]) == set()
+
+ def test_find_forsaken_nodes(self):
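+ # "Forsaken" nodes end up with an empty backend candidate set: no single backend can run them.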
+ graph = _GraphViewer.from_edges(
+ # a   b   c
+ #  \ / \ /
+ #   d   e
+ #    \ /
+ #     f
+ #    / \
+ #   g   h
+ [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f"), ("f", "g"), ("f", "h")],
+ supporting_backends_mapper=supporting_backends_from_node_id_dict(
+ {"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]}
+ ),
+ )
+ assert graph.find_forsaken_nodes() == {"e", "f", "g", "h"}
+
+ def test_find_articulation_points_basic(self):
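+ # Articulation points are the nodes where the graph can be cut into an upstream and a downstream part (candidate split locations).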
+ flat = {
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
+ }
+ graph = _GraphViewer.from_flat_graph(flat)
+ assert graph.find_articulation_points() == {"lc1", "ndvi1"}
+
+ @pytest.mark.parametrize(
+ ["flat", "expected"],
+ [
+ (
+ {
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}},
+ },
+ {"lc1", "ndvi1"},
+ ),
+ (
+ {
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "bands1": {
+ "process_id": "filter_bands",
+ "arguments": {"data": {"from_node": "lc1"}, "bands": ["b1"]},
+ },
+ "bands2": {
+ "process_id": "filter_bands",
+ "arguments": {"data": {"from_node": "lc1"}, "bands": ["b2"]},
+ },
+ "merge1": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}},
+ },
+ "save1": {
+ "process_id": "save_result",
+ "arguments": {"data": {"from_node": "merge1"}, "format": "GTiff"},
+ },
+ },
+ {"lc1", "merge1", "save1"},
+ ),
+ (
+ {
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "merge1": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
+ },
+ "save1": {
+ "process_id": "save_result",
+ "arguments": {"data": {"from_node": "merge1"}, "format": "GTiff"},
+ },
+ },
+ {"lc1", "lc2", "merge1", "save1"},
+ ),
+ (
+ {
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "bands1": {
+ "process_id": "filter_bands",
+ "arguments": {"data": {"from_node": "lc1"}, "bands": ["b1"]},
+ },
+ "bbox1": {
+ "process_id": "filter_spatial",
+ "arguments": {"data": {"from_node": "bands1"}},
+ },
+ "merge1": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "bbox1"}},
+ },
+ "save1": {
+ "process_id": "save_result",
+ "arguments": {"data": {"from_node": "merge1"}},
+ },
+ },
+ {"lc1", "merge1", "save1"},
+ ),
+ ],
+ )
+ def test_find_articulation_points(self, flat, expected):
+ graph = _GraphViewer.from_flat_graph(flat)
+ assert graph.find_articulation_points() == expected
+
+ def test_split_at_minimal(self):
+ graph = _GraphViewer.from_edges(
+ [("a", "b")], supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"})
+ )
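+ # split_at returns an (up, down) pair of subgraphs; the split node itself appears in both parts.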
+ # Split at a
+ up, down = graph.split_at("a")
+ assert sorted(up.iter_nodes()) == [
+ ("a", _GVNode(backend_candidates=["A"])),
+ ]
+ assert sorted(down.iter_nodes()) == [
+ ("a", _GVNode(flows_to=["b"])),
+ ("b", _GVNode(depends_on=["a"])),
+ ]
+ # Split at b
+ up, down = graph.split_at("b")
+ assert sorted(up.iter_nodes()) == [
+ ("a", _GVNode(flows_to=["b"], backend_candidates=["A"])),
+ ("b", _GVNode(depends_on=["a"])),
+ ]
+ assert sorted(down.iter_nodes()) == [
+ ("b", _GVNode()),
+ ]
+
+ def test_split_at_basic(self):
+ graph = _GraphViewer.from_edges(
+ [("a", "b"), ("b", "c")],
+ supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"}),
+ )
+ up, down = graph.split_at("b")
+ assert sorted(up.iter_nodes()) == [
+ ("a", _GVNode(flows_to=["b"], backend_candidates=["A"])),
+ ("b", _GVNode(depends_on=["a"])),
+ ]
+ assert sorted(down.iter_nodes()) == [
+ ("b", _GVNode(flows_to=["c"])),
+ ("c", _GVNode(depends_on=["b"])),
+ ]
+
+ def test_split_at_complex(self):
+ graph = _GraphViewer.from_edges(
+ #    a
+ #   / \
+ #  b   c      X
+ #   \ / \     |
+ #    d   e  f Y
+ #         \ /
+ #          g
+ [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e"), ("e", "g"), ("f", "g"), ("X", "Y")]
+ )
+ up, down = graph.split_at("e")
+ assert sorted(up.iter_nodes()) == sorted(
+ _GraphViewer.from_edges([("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e")]).iter_nodes()
+ )
+ assert sorted(down.iter_nodes()) == sorted(
+ _GraphViewer.from_edges([("e", "g"), ("f", "g"), ("X", "Y")]).iter_nodes()
+ )
+
+ def test_split_at_non_articulation_point(self):
+ graph = _GraphViewer.from_edges(
+ #   a
+ #  /|
+ # b |
+ #  \|
+ #   c
+ [("a", "b"), ("b", "c"), ("a", "c")]
+ )
+
+ with pytest.raises(GraphSplitException, match="not an articulation point"):
+ _ = graph.split_at("b")
+
+ # These should still work
+ up, down = graph.split_at("a")
+ assert sorted(up.iter_nodes()) == [
+ ("a", _GVNode()),
+ ]
+ assert sorted(down.iter_nodes()) == [
+ ("a", _GVNode(flows_to=["b", "c"])),
+ ("b", _GVNode(depends_on=["a"], flows_to=["c"])),
+ ("c", _GVNode(depends_on=["a", "b"])),
+ ]
+
+ up, down = graph.split_at("c")
+ assert sorted(up.iter_nodes()) == [
+ ("a", _GVNode(flows_to=["b", "c"])),
+ ("b", _GVNode(depends_on=["a"], flows_to=["c"])),
+ ("c", _GVNode(depends_on=["a", "b"])),
+ ]
+ assert sorted(down.iter_nodes()) == [
+ ("c", _GVNode()),
+ ]
+
+ def test_split_at_multiple_empty(self):
+ graph = _GraphViewer.from_edges([("a", "b")])
+ result = graph.split_at_multiple([])
+ assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == {
+ None: [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))],
+ }
+
+ def test_split_at_multiple_single(self):
+ graph = _GraphViewer.from_edges([("a", "b"), ("b", "c")])
+ result = graph.split_at_multiple(["b"])
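+ # Each subgraph is keyed by its split node; the remaining downstream part gets key None.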
+ assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == {
+ "b": [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))],
+ None: [("b", _GVNode(flows_to="c")), ("c", _GVNode(depends_on="b"))],
+ }
+
+ def test_split_at_multiple_basic(self):
+ graph = _GraphViewer.from_edges(
+ [("a", "b"), ("b", "c"), ("c", "d")],
+ supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"}),
+ )
+ result = graph.split_at_multiple(["b", "c"])
+ assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == {
+ "b": [("a", _GVNode(flows_to="b", backend_candidates="A")), ("b", _GVNode(depends_on="a"))],
+ "c": [("b", _GVNode(flows_to="c")), ("c", _GVNode(depends_on="b"))],
+ None: [("c", _GVNode(flows_to="d")), ("d", _GVNode(depends_on="c"))],
+ }
+
+ def test_split_at_multiple_invalid(self):
+ """Split nodes should be in downstream order"""
+ graph = _GraphViewer.from_edges(
+ [("a", "b"), ("b", "c"), ("c", "d")],
+ )
+ # Downstream order: works
+ _ = graph.split_at_multiple(["b", "c"])
+ # Upstream order: fails
+ with pytest.raises(GraphSplitException, match="Invalid node id 'b'"):
+ _ = graph.split_at_multiple(["c", "b"])
+
+ def test_produce_split_locations_simple(self):
+ """Simple produce_split_locations use case: no need for splits"""
+ flat = {
+ # lc1
+ # |
+ # ndvi1
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
+ }
+ graph = _GraphViewer.from_flat_graph(
+ flat, supporting_backends=supporting_backends_from_node_id_dict({"lc1": "b1"})
+ )
+ assert list(graph.produce_split_locations()) == [[]]
+
+ def test_produce_split_locations_merge_basic(self):
+ """
+ Basic produce_split_locations use case:
+ two load collections on different backends and a merge
+ """
+ flat = {
+ # lc1 lc2
+ # \ /
+ # merge1
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "merge1": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
+ },
+ }
+ graph = _GraphViewer.from_flat_graph(
+ flat,
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+ )
+ assert sorted(graph.produce_split_locations()) == [["lc1"], ["lc2"]]
+
+ def test_produce_split_locations_merge_longer(self):
+ flat = {
+ # lc1 lc2
+ # | |
+ # bands1 bands2
+ # \ /
+ # merge1
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
+ "merge1": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}},
+ },
+ }
+ graph = _GraphViewer.from_flat_graph(
+ flat,
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+ )
+ assert sorted(graph.produce_split_locations(limit=2)) == [["bands1"], ["bands2"]]
+ assert list(graph.produce_split_locations(limit=4)) == [["bands1"], ["bands2"], ["lc1"], ["lc2"]]
+
+ def test_produce_split_locations_merge_longer_triangle(self):
+ flat = {
+ #     lc1
+ #    /   |
+ # bands1 |    lc2
+ #     \  |     |
+ #     mask1  bands2
+ #         \   /
+ #         merge1
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+ "mask1": {
+ "process_id": "mask",
+ "arguments": {"data": {"from_node": "bands1"}, "mask": {"from_node": "lc1"}},
+ },
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
+ "merge1": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "mask1"}, "cube2": {"from_node": "bands2"}},
+ },
+ }
+ graph = _GraphViewer.from_flat_graph(
+ flat,
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+ )
+ assert list(graph.produce_split_locations(limit=4)) == [["bands2"], ["mask1"], ["lc2"], ["lc1"]]
+
+ def test_produce_split_locations_allow_split(self):
+ """Usage of custom allow_list predicate"""
+ flat = {
+ # lc1 lc2
+ # | |
+ # bands1 bands2
+ # \ /
+ # merge1
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
+ "merge1": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}},
+ },
+ }
+ graph = _GraphViewer.from_flat_graph(
+ flat,
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+ )
+ assert list(graph.produce_split_locations()) == [["bands1"], ["bands2"], ["lc1"], ["lc2"]]
+ assert list(graph.produce_split_locations(allow_split=lambda n: n not in {"bands1", "lc2"})) == [
+ ["bands2"],
+ ["lc1"],
+ ]
+
+
+class TestDeepGraphSplitter:
+ def test_no_split(self):
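+ # DeepGraphSplitter.split returns a _PGSplitResult: the primary sub-graph (node ids + backend id) plus secondary sub-graphs, each with its split node, node ids and backend id.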
+ splitter = DeepGraphSplitter(supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"]}))
+ flat = {
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
+ }
+ result = splitter.split(flat)
+ assert result == _PGSplitResult(
+ primary_node_ids={"lc1", "ndvi1"},
+ primary_backend_id="b1",
+ secondary_graphs=[],
+ )
+
+ def test_simple_split(self):
+ """
+ Simplest split use case: two load_collections from different backends, merged.
+ """
+ splitter = DeepGraphSplitter(
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]})
+ )
+ flat = {
+ # lc1 lc2
+ # \ /
+ # merge
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "merge": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
+ "result": True,
+ },
+ }
+ result = splitter.split(flat)
+ assert result == _PGSplitResult(
+ primary_node_ids={"lc1", "lc2", "merge"},
+ primary_backend_id="b2",
+ secondary_graphs=[
+ _PGSplitSubGraph(
+ split_node="lc1",
+ node_ids={"lc1"},
+ backend_id="b1",
+ )
+ ],
+ )
+
+ def test_simple_deep_split(self):
+ """
+ Simple deep split use case:
+ two load_collections from different backends, with some additional filtering, merged.
+ """
+ splitter = DeepGraphSplitter(
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]})
+ )
+ flat = {
+ # lc1 lc2
+ # | |
+ # bands1 temporal2
+ # \ /
+ # merge
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+ "temporal2": {
+ "process_id": "filter_temporal",
+ "arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]},
+ },
+ "merge": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}},
+ "result": True,
+ },
+ }
+ result = splitter.split(flat)
+ assert result == _PGSplitResult(
+ primary_node_ids={"lc2", "temporal2", "bands1", "merge"},
+ primary_backend_id="b2",
+ secondary_graphs=[_PGSplitSubGraph(split_node="bands1", node_ids={"lc1", "bands1"}, backend_id="b1")],
+ )
+
+ def test_shallow_triple_split(self):
+ splitter = DeepGraphSplitter(
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]})
+ )
+ flat = {
+ # lc1  lc2   lc3
+ #   \  /     /
+ #  merge1   /
+ #      \   /
+ #      merge2
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "lc3": {"process_id": "load_collection", "arguments": {"id": "S3"}},
+ "merge1": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
+ },
+ "merge2": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "merge1"}, "cube2": {"from_node": "lc3"}},
+ "result": True,
+ },
+ }
+ result = splitter.split(flat)
+ assert result == _PGSplitResult(
+ primary_node_ids={"lc1", "lc2", "lc3", "merge1", "merge2"},
+ primary_backend_id="b2",
+ secondary_graphs=[
+ _PGSplitSubGraph(split_node="lc1", node_ids={"lc1"}, backend_id="b1"),
+ _PGSplitSubGraph(split_node="lc3", node_ids={"lc3"}, backend_id="b3"),
+ ],
+ )
+
+ def test_triple_split(self):
+ splitter = DeepGraphSplitter(
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]})
+ )
+ flat = {
+ # lc1      lc2        lc3
+ #  |        |          |
+ # bands1  temporal2  spatial3
+ #      \   /          /
+ #      merge1        /
+ #           \       /
+ #            merge2
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "lc3": {"process_id": "load_collection", "arguments": {"id": "S3"}},
+ "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+ "temporal2": {
+ "process_id": "filter_temporal",
+ "arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]},
+ },
+ "spatial3": {"process_id": "filter_spatial", "arguments": {"data": {"from_node": "lc3"}, "extent": "EU"}},
+ "merge1": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}},
+ },
+ "merge2": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "merge1"}, "cube2": {"from_node": "spatial3"}},
+ "result": True,
+ },
+ }
+ result = splitter.split(flat)
+ assert result == _PGSplitResult(
+ primary_node_ids={"merge2", "merge1", "lc3", "spatial3"},
+ primary_backend_id="b3",
+ secondary_graphs=[
+ _PGSplitSubGraph(split_node="bands1", node_ids={"bands1", "lc1"}, backend_id="b1"),
+ _PGSplitSubGraph(
+ split_node="merge1", node_ids={"bands1", "merge1", "temporal2", "lc2"}, backend_id="b2"
+ ),
+ ],
+ )
+
+ @pytest.mark.parametrize(
+ ["primary_backend", "secondary_graph"],
+ [
+ ("b1", _PGSplitSubGraph(split_node="lc2", node_ids={"lc2"}, backend_id="b2")),
+ ("b2", _PGSplitSubGraph(split_node="lc1", node_ids={"lc1"}, backend_id="b1")),
+ ],
+ )
+ def test_split_with_primary_backend(self, primary_backend, secondary_graph):
+ """Test `primary_backend` argument of DeepGraphSplitter"""
+ splitter = DeepGraphSplitter(
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+ primary_backend=primary_backend,
+ )
+ flat = {
+ # lc1 lc2
+ # \ /
+ # merge
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "merge": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
+ "result": True,
+ },
+ }
+ result = splitter.split(flat)
+ assert result == _PGSplitResult(
+ primary_node_ids={"lc1", "lc2", "merge"},
+ primary_backend_id=primary_backend,
+ secondary_graphs=[secondary_graph],
+ )
+
+ @pytest.mark.parametrize(
+ ["split_deny_list", "split_node", "primary_node_ids", "secondary_node_ids"],
+ [
+ ({}, "temporal2", {"lc1", "bands1", "temporal2", "merge"}, {"lc2", "temporal2"}),
+ ({"filter_bands", "filter_temporal"}, "lc2", {"lc1", "lc2", "bands1", "temporal2", "merge"}, {"lc2"}),
+ ],
+ )
+ def test_split_deny_list(self, split_deny_list, split_node, primary_node_ids, secondary_node_ids):
+ """
+ Simple deep split use case:
+ two load_collections from different backends, with some additional filtering, merged.
+ """
+ splitter = DeepGraphSplitter(
+ supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+ primary_backend="b1",
+ split_deny_list=split_deny_list,
+ )
+ flat = {
+ # lc1 lc2
+ # | |
+ # bands1 temporal2
+ # \ /
+ # merge
+ "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+ "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+ "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+ "temporal2": {
+ "process_id": "filter_temporal",
+ "arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]},
+ },
+ "merge": {
+ "process_id": "merge_cubes",
+ "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}},
+ "result": True,
+ },
+ }
+ result = splitter.split(flat)
+ assert result == _PGSplitResult(
+ primary_node_ids=primary_node_ids,
+ primary_backend_id="b1",
+ secondary_graphs=[_PGSplitSubGraph(split_node=split_node, node_ids=secondary_node_ids, backend_id="b2")],
+ )