diff --git a/CHANGELOG.md b/CHANGELOG.md index d192b98..f68f61c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is roughly based on [Keep a Changelog](https://keepachangelog.com/en/ +## 0.39.0 + +- More advanced process-graph splitting for cross-backend execution: not limited to splitting off `load_collection` nodes, but cut deeper into the graph. ([#150](https://github.com/Open-EO/openeo-aggregator/issues/150)) + ## 0.38.0 - Add request timeout configs for listing user jobs (eu-cdse/openeo-cdse-infra#188) diff --git a/scripts/crossbackend-processing-poc.py b/scripts/crossbackend-processing-poc.py index e8b83e5..01f6a33 100644 --- a/scripts/crossbackend-processing-poc.py +++ b/scripts/crossbackend-processing-poc.py @@ -7,7 +7,8 @@ from openeo_aggregator.metadata import STAC_PROPERTY_FEDERATION_BACKENDS from openeo_aggregator.partitionedjobs import PartitionedJob from openeo_aggregator.partitionedjobs.crossbackend import ( - CrossBackendSplitter, + CrossBackendJobSplitter, + LoadCollectionGraphSplitter, run_partitioned_job, ) @@ -62,7 +63,9 @@ def backend_for_collection(collection_id) -> str: metadata = connection.describe_collection(collection_id) return metadata["summaries"][STAC_PROPERTY_FEDERATION_BACKENDS][0] - splitter = CrossBackendSplitter(backend_for_collection=backend_for_collection, always_split=True) + splitter = CrossBackendJobSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=backend_for_collection, always_split=True) + ) pjob: PartitionedJob = splitter.split({"process_graph": process_graph}) _log.info(f"Partitioned job: {pjob!r}") diff --git a/src/openeo_aggregator/about.py b/src/openeo_aggregator/about.py index 5287e3d..febd27f 100644 --- a/src/openeo_aggregator/about.py +++ b/src/openeo_aggregator/about.py @@ -2,7 +2,7 @@ import sys from typing import Optional -__version__ = "0.38.0a1" +__version__ = "0.39.0a1" def log_version_info(logger: Optional[logging.Logger] = None): diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py index b568ce3..398aaeb 100644 --- a/src/openeo_aggregator/backend.py +++ b/src/openeo_aggregator/backend.py @@ -84,6 +84,7 @@ streaming_flask_response, ) from openeo_aggregator.constants import ( + CROSSBACKEND_GRAPH_SPLIT_METHOD, JOB_OPTION_FORCE_BACKEND, JOB_OPTION_SPLIT_STRATEGY, JOB_OPTION_TILE_GRID, @@ -100,7 +101,11 @@ single_backend_collection_post_processing, ) from openeo_aggregator.partitionedjobs import PartitionedJob -from openeo_aggregator.partitionedjobs.crossbackend import CrossBackendSplitter +from openeo_aggregator.partitionedjobs.crossbackend import ( + CrossBackendJobSplitter, + DeepGraphSplitter, + LoadCollectionGraphSplitter, +) from openeo_aggregator.partitionedjobs.splitting import FlimsySplitter, TileGridSplitter from openeo_aggregator.partitionedjobs.tracking import ( PartitionedJobConnection, @@ -803,25 +808,34 @@ def create_job( if "process_graph" not in process: raise ProcessGraphMissingException() - # TODO: better, more generic/specific job_option(s)? 
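For illustration (not part of the patch itself): the `split_strategy` job-option shapes that the reworked `create_job` handling below is meant to accept, as exercised by the updated tests in this diff. A sketch of the plain request-payload dicts:

```python
# Legacy string form: cross-backend splitting with the "simple" graph split method.
legacy = {"split_strategy": "crossbackend"}

# New dict form, explicit "simple" method (only load_collection nodes are split off).
simple = {"split_strategy": {"crossbackend": {"method": "simple"}}}

# New dict form, "deep" method with a pinned primary backend.
deep = {"split_strategy": {"crossbackend": {"method": "deep", "primary_backend": "b1"}}}
```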
- if job_options and (job_options.get(JOB_OPTION_SPLIT_STRATEGY) or job_options.get(JOB_OPTION_TILE_GRID)): - if job_options.get(JOB_OPTION_SPLIT_STRATEGY) == "crossbackend": - # TODO this is temporary feature flag to trigger "crossbackend" splitting - return self._create_crossbackend_job( - user_id=user_id, - process=process, - api_version=api_version, - metadata=metadata, - job_options=job_options, - ) - else: - return self._create_partitioned_job( - user_id=user_id, - process=process, - api_version=api_version, - metadata=metadata, - job_options=job_options, - ) + # Coverage of messy "split_strategy" job option + # Also see https://github.com/Open-EO/openeo-aggregator/issues/156 + # TODO: more generic and future proof handling of split strategy related options? + split_strategy = (job_options or {}).get(JOB_OPTION_SPLIT_STRATEGY) + tile_grid = (job_options or {}).get(JOB_OPTION_TILE_GRID) + crossbackend_mode = split_strategy == "crossbackend" or ( + isinstance(split_strategy, dict) and "crossbackend" in split_strategy + ) + # TODO: the legacy job option "tile_grid" is quite generic and not very explicit + # about being a job splitting approach. Can we deprecate this in a way? + spatial_split_mode = tile_grid or split_strategy == "flimsy" + + if crossbackend_mode: + return self._create_crossbackend_job( + user_id=user_id, + process=process, + api_version=api_version, + metadata=metadata, + job_options=job_options, + ) + elif spatial_split_mode: + return self._create_partitioned_job( + user_id=user_id, + process=process, + api_version=api_version, + metadata=metadata, + job_options=job_options, + ) else: return self._create_job_standard( user_id=user_id, @@ -936,14 +950,47 @@ def _create_crossbackend_job( if not self.partitioned_job_tracker: raise FeatureUnsupportedException(message="Partitioned job tracking is not supported") - def backend_for_collection(collection_id) -> str: - return self._catalog.get_backends_for_collection(cid=collection_id)[0] + split_strategy = (job_options or {}).get(JOB_OPTION_SPLIT_STRATEGY) + if split_strategy == "crossbackend": + # Legacy job option format + graph_split_method = CROSSBACKEND_GRAPH_SPLIT_METHOD.SIMPLE + elif isinstance(split_strategy, dict) and isinstance(split_strategy.get("crossbackend"), dict): + graph_split_method = split_strategy.get("crossbackend", {}).get( + "method", CROSSBACKEND_GRAPH_SPLIT_METHOD.SIMPLE + ) + else: + raise ValueError(f"Invalid split strategy {split_strategy!r}") - splitter = CrossBackendSplitter( - backend_for_collection=backend_for_collection, - # TODO: job option for `always_split` feature? - always_split=True, - ) + _log.info(f"_create_crossbackend_job: {graph_split_method=} from {split_strategy=}") + if graph_split_method == CROSSBACKEND_GRAPH_SPLIT_METHOD.SIMPLE: + + def backend_for_collection(collection_id) -> str: + return self._catalog.get_backends_for_collection(cid=collection_id)[0] + + graph_splitter = LoadCollectionGraphSplitter( + backend_for_collection=backend_for_collection, + # TODO: job option for `always_split` feature? 
+ always_split=True, + ) + elif graph_split_method == CROSSBACKEND_GRAPH_SPLIT_METHOD.DEEP: + + def supporting_backends(node_id: str, node: dict) -> Union[List[str], None]: + # TODO: wider coverage checking process id availability + if node["process_id"] == "load_collection": + collection_id = node["arguments"]["id"] + return self._catalog.get_backends_for_collection(cid=collection_id) + + graph_splitter = DeepGraphSplitter( + supporting_backends=supporting_backends, + primary_backend=split_strategy.get("crossbackend", {}).get("primary_backend"), + # TODO: instead of this hardcoded deny-list, build it based on backend metadata inspection? + # TODO: make a config for this? + split_deny_list={"aggregate_spatial", "load_geojson", "load_url"}, + ) + else: + raise ValueError(f"Invalid graph split strategy {graph_split_method!r}") + + splitter = CrossBackendJobSplitter(graph_splitter=graph_splitter) pjob_id = self.partitioned_job_tracker.create_crossbackend_pjob( user_id=user_id, process=process, metadata=metadata, job_options=job_options, splitter=splitter diff --git a/src/openeo_aggregator/constants.py b/src/openeo_aggregator/constants.py index 68621bd..02a7a4a 100644 --- a/src/openeo_aggregator/constants.py +++ b/src/openeo_aggregator/constants.py @@ -4,3 +4,9 @@ # Experimental feature to force a certain upstream back-end through job options JOB_OPTION_FORCE_BACKEND = "_agg_force_backend" + + +class CROSSBACKEND_GRAPH_SPLIT_METHOD: + # Poor-man's StrEnum + SIMPLE = "simple" + DEEP = "deep" diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index 9c97300..076b021 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -1,27 +1,68 @@ +from __future__ import annotations + +import abc import collections import copy +import dataclasses import datetime +import fractions +import functools import itertools import logging import time +import types from contextlib import nullcontext -from typing import Callable, Dict, Iterator, List, Optional, Protocol, Sequence, Tuple +from typing import ( + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + NamedTuple, + Optional, + Protocol, + Sequence, + Set, + Tuple, + Union, +) import openeo from openeo import BatchJob +from openeo.util import deep_get from openeo_driver.jobregistry import JOB_STATUS from openeo_aggregator.constants import JOB_OPTION_FORCE_BACKEND from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob from openeo_aggregator.partitionedjobs.splitting import AbstractJobSplitter -from openeo_aggregator.utils import FlatPG, PGWithMetadata, SkipIntermittentFailures +from openeo_aggregator.utils import ( + _UNSET, + FlatPG, + PGWithMetadata, + SkipIntermittentFailures, +) _log = logging.getLogger(__name__) _LOAD_RESULT_PLACEHOLDER = "_placeholder:" # Some type annotation aliases to make things more self-documenting +CollectionId = str SubGraphId = str +NodeId = str +BackendId = str +ProcessId = str + + +# Annotation for a function that maps node information (id and its node dict) +# to id(s) of the backend(s) that support it. +# Returning None means that support is unconstrained (any backend is assumed to support it). 
+SupportingBackendsMapper = Callable[[NodeId, dict], Union[BackendId, Iterable[BackendId], None]] + + +class GraphSplitException(Exception): + pass class GetReplacementCallable(Protocol): @@ -57,49 +98,57 @@ def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) } -class CrossBackendSplitter(AbstractJobSplitter): - """ - Split a process graph, to be executed across multiple back-ends, - based on availability of collections. +class _PGSplitSubGraph(NamedTuple): + """Container for result of ProcessGraphSplitterInterface.split""" + + split_node: NodeId + node_ids: Set[NodeId] + backend_id: BackendId - .. warning:: - this is experimental functionality +class _PGSplitResult(NamedTuple): + """Container for result of ProcessGraphSplitterInterface.split""" + + primary_node_ids: Set[NodeId] + primary_backend_id: BackendId + secondary_graphs: List[_PGSplitSubGraph] + + +class ProcessGraphSplitterInterface(metaclass=abc.ABCMeta): + """ + Interface for process graph splitters: + given a process graph (flat graph representation), + produce a main graph and secondary graphs (as subsets of node ids) + and the backends they are supposed to run on. """ - def __init__(self, backend_for_collection: Callable[[str], str], always_split: bool = False): - """ - :param backend_for_collection: callable that determines backend id for given collection id - :param always_split: split all load_collections, also when on same backend + @abc.abstractmethod + def split(self, process_graph: FlatPG) -> _PGSplitResult: """ - # TODO: just handle this `backend_for_collection` callback with a regular method? - self.backend_for_collection = backend_for_collection - self._always_split = always_split + Split given process graph (flat graph representation) into sub graphs - def split_streaming( - self, - process_graph: FlatPG, - get_replacement: GetReplacementCallable = _default_get_replacement, - main_subgraph_id: SubGraphId = "main", - ) -> Iterator[Tuple[SubGraphId, SubJob, List[SubGraphId]]]: + Returns primary graph data (node ids and backend id) + and secondary graphs data (list of tuples: split node id, subgraph node ids,backend id) """ - Split given process graph in sub-process graphs and return these as an iterator - in an order so that a subgraph comes after all subgraphs it depends on - (e.g. main "primary" graph comes last). + ... - The iterator approach allows working with a dynamic `get_replacement` implementation - that adapting to on previously produced subgraphs - (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately). - :return: tuple containing: - - subgraph id, recommended to handle it as opaque id (but usually format '{backend_id}:{node_id}') - - SubJob - - dependencies as list of subgraph ids - """ +class LoadCollectionGraphSplitter(ProcessGraphSplitterInterface): + """ + Simple process graph splitter that just splits off load_collection nodes. + """ + + # TODO: migrate backend_for_collection to SupportingBackendsMapper format? + def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False): + # TODO: also support not not having a backend_for_collection map? 
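As a concrete illustration of the `SupportingBackendsMapper` contract introduced above: a minimal sketch of a mapper that only constrains `load_collection` nodes and leaves every other node unconstrained by returning `None`. The `COLLECTION_BACKENDS` lookup is a made-up example, not something defined in this repository:

```python
from typing import List, Optional

# Hypothetical collection -> backends mapping, for illustration only.
COLLECTION_BACKENDS = {"S2": ["b1"], "T22": ["b2"]}


def supporting_backends(node_id: str, node: dict) -> Optional[List[str]]:
    """Constrain load_collection nodes to the backends hosting their collection."""
    if node.get("process_id") == "load_collection":
        return COLLECTION_BACKENDS.get(node["arguments"]["id"])
    # None means: unconstrained, any backend is assumed to support this node.
    return None
```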
+ self._backend_for_collection = backend_for_collection + self._always_split = always_split + + def split(self, process_graph: FlatPG) -> _PGSplitResult: # Extract necessary back-ends from `load_collection` usage backend_per_collection: Dict[str, str] = { - cid: self.backend_for_collection(cid) + cid: self._backend_for_collection(cid) for cid in ( node["arguments"]["id"] for node in process_graph.values() if node["process_id"] == "load_collection" ) @@ -112,45 +161,93 @@ def split_streaming( secondary_backends = {b for b in backend_usage if b != primary_backend} _log.info(f"Backend split: {primary_backend=} {secondary_backends=}") - primary_id = main_subgraph_id - primary_pg = {} primary_has_load_collection = False - primary_dependencies = [] - + primary_graph_node_ids = set() + secondary_graphs: List[_PGSplitSubGraph] = [] for node_id, node in process_graph.items(): if node["process_id"] == "load_collection": bid = backend_per_collection[node["arguments"]["id"]] if bid == primary_backend and (not self._always_split or not primary_has_load_collection): - # Add to primary pg - primary_pg[node_id] = node + primary_graph_node_ids.add(node_id) primary_has_load_collection = True else: - # New secondary pg - sub_id = f"{bid}:{node_id}" - sub_pg = { - node_id: node, - "sr1": { - # TODO: other/better choices for save_result format (e.g. based on backend support)? - "process_id": "save_result", - "arguments": { - "data": {"from_node": node_id}, - # TODO: particular format options? - # "format": "NetCDF", - "format": "GTiff", - }, - "result": True, - }, - } + secondary_graphs.append(_PGSplitSubGraph(split_node=node_id, node_ids={node_id}, backend_id=bid)) + else: + primary_graph_node_ids.add(node_id) - yield (sub_id, SubJob(process_graph=sub_pg, backend_id=bid), []) + return _PGSplitResult( + primary_node_ids=primary_graph_node_ids, + primary_backend_id=primary_backend, + secondary_graphs=secondary_graphs, + ) - # Link secondary pg into primary pg - primary_pg.update(get_replacement(node_id=node_id, node=node, subgraph_id=sub_id)) - primary_dependencies.append(sub_id) - else: - primary_pg[node_id] = node - yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies) +class CrossBackendJobSplitter(AbstractJobSplitter): + """ + Split a process graph, to be executed across multiple back-ends, + based on availability of collections. + + .. warning:: + this is experimental functionality + + """ + + def __init__(self, graph_splitter: ProcessGraphSplitterInterface): + self._graph_splitter = graph_splitter + + def split_streaming( + self, + process_graph: FlatPG, + get_replacement: GetReplacementCallable = _default_get_replacement, + main_subgraph_id: SubGraphId = "main", + ) -> Iterator[Tuple[SubGraphId, SubJob, List[SubGraphId]]]: + """ + Split given process graph in sub-process graphs and return these as an iterator + in an order so that a subgraph comes after all subgraphs it depends on + (e.g. main "primary" graph comes last). + + The iterator approach allows working with a dynamic `get_replacement` implementation + that can be adaptive to previously produced subgraphs + (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately). 
+ + :return: Iterator of tuples containing: + - subgraph id, it's recommended to handle it as opaque id (but usually format '{backend_id}:{node_id}') + - SubJob + - dependencies as list of subgraph ids + """ + + graph_split_result = self._graph_splitter.split(process_graph=process_graph) + + primary_pg = {k: process_graph[k] for k in graph_split_result.primary_node_ids} + primary_dependencies = [] + + for node_id, subgraph_node_ids, backend_id in graph_split_result.secondary_graphs: + # New secondary pg + sub_id = f"{backend_id}:{node_id}" + sub_pg = {k: v for k, v in process_graph.items() if k in subgraph_node_ids} + # Add new `save_result` node to the subgraphs + sub_pg["_agg_crossbackend_save_result"] = { + # TODO: other/better choices for save_result format (e.g. based on backend support, cube type)? + "process_id": "save_result", + "arguments": { + "data": {"from_node": node_id}, + # TODO: particular format options? + # "format": "NetCDF", + "format": "GTiff", + }, + "result": True, + } + yield (sub_id, SubJob(process_graph=sub_pg, backend_id=backend_id), []) + + # Link secondary pg into primary pg + primary_pg.update(get_replacement(node_id=node_id, node=process_graph[node_id], subgraph_id=sub_id)) + primary_dependencies.append(sub_id) + + yield ( + main_subgraph_id, + SubJob(process_graph=primary_pg, backend_id=graph_split_result.primary_backend_id), + primary_dependencies, + ) def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob: """Split given process graph into a `PartitionedJob`""" @@ -351,3 +448,465 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection, fai } for sid in subjobs.keys() } + + +def to_frozenset(value: Union[Iterable[str], str]) -> frozenset[str]: + """Coerce value to frozenset of strings""" + if isinstance(value, str): + value = [value] + return frozenset(value) + + +@dataclasses.dataclass(frozen=True, init=False, eq=True) +class _GVNode: + """ + Node in a _GraphViewer, with pointers to other nodes it depends on (needs data/input from) + and nodes to which it provides input to. + + This structure designed to be as immutable as possible (as far as Python allows) + to be (re)used in iterative/recursive graph handling algorithms, + without having to worry about accidentally propagating changed state to other parts of the graph. + """ + + # Node ids of other nodes this node depends on (aka parents) + depends_on: frozenset[NodeId] + # Node ids of other nodes that depend on this node (aka children) + flows_to: frozenset[NodeId] + + # Backend ids this node is marked to be supported on + # value None means it is unknown/unconstrained for this node + backend_candidates: Union[frozenset[BackendId], None] + + def __init__( + self, + *, + depends_on: Union[Iterable[NodeId], NodeId, None] = None, + flows_to: Union[Iterable[NodeId], NodeId, None] = None, + backend_candidates: Union[Iterable[BackendId], BackendId, None] = None, + ): + # TODO: type coercion in __init__ of frozen dataclasses is bit ugly. Use attrs with field converters instead? 
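Going back to `CrossBackendJobSplitter.split_streaming` above: a minimal usage sketch, assuming this package is importable and mirroring the splitter wiring used in the unit tests of this diff. The `get_replacement` URL is a placeholder; in the aggregator it is the canonical link to the secondary job's (partial) results:

```python
from openeo_aggregator.partitionedjobs.crossbackend import (
    CrossBackendJobSplitter,
    LoadCollectionGraphSplitter,
)

process_graph = {
    "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
    "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
    "mc1": {
        "process_id": "merge_cubes",
        "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
        "result": True,
    },
}

splitter = CrossBackendJobSplitter(
    graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
)


def get_replacement(node_id: str, node: dict, subgraph_id: str) -> dict:
    # Placeholder replacement: in the aggregator this becomes a load_stac
    # pointing at the secondary job's partial results.
    return {node_id: {"process_id": "load_stac", "arguments": {"url": f"https://example.test/{subgraph_id}"}}}


for subgraph_id, subjob, dependencies in splitter.split_streaming(process_graph, get_replacement=get_replacement):
    print(subgraph_id, subjob.backend_id, dependencies)
# The secondary subjob ("B2:lc2") is yielded first with no dependencies;
# the "main" subjob comes last, depending on ["B2:lc2"].
```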
+ super().__init__() + object.__setattr__(self, "depends_on", to_frozenset(depends_on or [])) + object.__setattr__(self, "flows_to", to_frozenset(flows_to or [])) + backend_candidates = to_frozenset(backend_candidates) if backend_candidates is not None else None + object.__setattr__(self, "backend_candidates", backend_candidates) + + def __repr__(self): + # Somewhat cryptic, but compact representation of node attributes + depends_on = (" <" + ",".join(sorted(self.depends_on))) if self.depends_on else "" + flows_to = (" >" + ",".join(sorted(self.flows_to))) if self.flows_to else "" + backends = (" @" + ",".join(sorted(self.backend_candidates))) if self.backend_candidates else "" + return f"[{type(self).__name__}{depends_on}{flows_to}{backends}]" + + +class _GraphViewer: + """ + Internal utility to have read-only view on the topological structure of a proces graph + and track the flow of backend support. + + """ + + def __init__(self, node_map: dict[NodeId, _GVNode]): + self._check_consistency(node_map=node_map) + # Work with a read-only proxy to prevent accidental changes + self._graph: Mapping[NodeId, _GVNode] = types.MappingProxyType(node_map.copy()) + + @staticmethod + def _check_consistency(node_map: dict[NodeId, _GVNode]): + """Check (link) consistency of given node map""" + key_ids = set(node_map.keys()) + linked_ids = set(k for n in node_map.values() for k in n.depends_on.union(n.flows_to)) + unknown = linked_ids.difference(key_ids) + if unknown: + raise GraphSplitException(f"Inconsistent node map: {key_ids=} != {linked_ids=}: {unknown=}") + bad_links = set() + for node_id, node in node_map.items(): + bad_links.update((other, node_id) for other in node.depends_on if node_id not in node_map[other].flows_to) + bad_links.update((node_id, other) for other in node.flows_to if node_id not in node_map[other].depends_on) + if bad_links: + raise GraphSplitException(f"Inconsistent node map: {bad_links=}") + + def __repr__(self): + return f"<{type(self).__name__}({self._graph})>" + + @classmethod + def from_flat_graph(cls, flat_graph: FlatPG, supporting_backends: SupportingBackendsMapper = (lambda n, d: None)): + """ + Build _GraphViewer from a flat process graph representation + """ + _log.debug(f"_GraphViewer.from_flat_graph: {flat_graph.keys()=}") + # Extract dependency links between nodes + depends_on = collections.defaultdict(list) + flows_to = collections.defaultdict(list) + for node_id, node in flat_graph.items(): + for arg_value in node.get("arguments", {}).values(): + if isinstance(arg_value, dict) and list(arg_value.keys()) == ["from_node"]: + from_node = arg_value["from_node"] + depends_on[node_id].append(from_node) + flows_to[from_node].append(node_id) + graph = { + node_id: _GVNode( + depends_on=depends_on.get(node_id, []), + flows_to=flows_to.get(node_id, []), + backend_candidates=supporting_backends(node_id, node), + ) + for node_id, node in flat_graph.items() + } + return cls(node_map=graph) + + @classmethod + def from_edges( + cls, + edges: Iterable[Tuple[NodeId, NodeId]], + supporting_backends_mapper: SupportingBackendsMapper = (lambda n, d: None), + ): + """ + Simple factory to build graph from parent-child tuples for testing purposes + """ + depends_on = collections.defaultdict(list) + flows_to = collections.defaultdict(list) + for parent, child in edges: + depends_on[child].append(parent) + flows_to[parent].append(child) + + graph = { + node_id: _GVNode( + depends_on=depends_on.get(node_id, []), + flows_to=flows_to.get(node_id, []), + 
backend_candidates=supporting_backends_mapper(node_id, {}), + ) + for node_id in set(depends_on.keys()).union(flows_to.keys()) + } + return cls(node_map=graph) + + def node(self, node_id: NodeId) -> _GVNode: + if node_id not in self._graph: + raise GraphSplitException(f"Invalid node id {node_id!r}.") + return self._graph[node_id] + + def iter_nodes(self) -> Iterator[Tuple[NodeId, _GVNode]]: + """Iterate through node_id-node pairs""" + yield from self._graph.items() + + def _walk( + self, + seeds: Iterable[NodeId], + next_nodes: Callable[[NodeId], Iterable[NodeId]], + include_seeds: bool = True, + auto_sort: bool = True, + ) -> Iterator[NodeId]: + """ + Walk the graph nodes starting from given seed nodes, + taking steps as defined by `next_nodes` function. + Walks breadth first and each node is only visited once. + + :param include_seeds: whether to include the seed nodes in the walk + :param auto_sort: visit "next" nodes of a given node lexicographically sorted + to make the walk deterministic. + """ + # TODO: option to walk depth first instead of breadth first? + if auto_sort: + # Automatically sort next nodes to make walk more deterministic + prepare = sorted + else: + prepare = lambda x: x + + if include_seeds: + visited = set() + to_visit = list(prepare(seeds)) + else: + visited = set(seeds) + to_visit = [n for s in seeds for n in prepare(next_nodes(s))] + + while to_visit: + node_id = to_visit.pop(0) + if node_id in visited: + continue + yield node_id + visited.add(node_id) + to_visit.extend(prepare(set(next_nodes(node_id)).difference(visited))) + + def walk_upstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]: + """ + Walk upstream nodes (along `depends_on` link) starting from given seed nodes. + Optionally include seeds or not, and walk breadth first. + """ + return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).depends_on, include_seeds=include_seeds) + + def walk_downstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]: + """ + Walk downstream nodes (along `flows_to` link) starting from given seed nodes. + Optionally include seeds or not, and walk breadth first. + """ + return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).flows_to, include_seeds=include_seeds) + + def get_backend_candidates_for_node(self, node_id: NodeId) -> Union[frozenset[BackendId], None]: + """Determine backend candidates for given node id""" + # TODO: cache intermediate sets? (Only when caching is safe: e.g. 
wrapped graph is immutable/not manipulated) + if self.node(node_id).backend_candidates is not None: + # Node has explicit backend candidates listed + return self.node(node_id).backend_candidates + elif self.node(node_id).depends_on: + # Backend support is unset: determine it (as intersection) from upstream nodes + return self.get_backend_candidates_for_node_set(self.node(node_id).depends_on) + else: + return None + + def get_backend_candidates_for_node_set(self, node_ids: Iterable[NodeId]) -> Union[frozenset[BackendId], None]: + """ + Determine backend candidates for a set of nodes + """ + candidates = set(self.get_backend_candidates_for_node(n) for n in node_ids) + if candidates == {None}: + return None + candidates.discard(None) + return functools.reduce(lambda a, b: a.intersection(b), candidates) + + def find_forsaken_nodes(self) -> Set[NodeId]: + """ + Find nodes that have no backend candidates to process them + """ + return set( + node_id for (node_id, _) in self.iter_nodes() if self.get_backend_candidates_for_node(node_id) == set() + ) + + def find_articulation_points(self) -> Set[NodeId]: + """ + Find articulation points (cut vertices) in the directed graph: + nodes that when removed would split the graph into multiple sub-graphs. + + Note that, unlike in traditional graph theory, the search also includes leaf nodes + (e.g. nodes with no parents), as in this context of openEO graph splitting, + when we "cut" a node, we replace it with two disconnected new nodes + (one connecting to the original parents and one connecting to the original children). + """ + # Approach: label the start nodes (e.g. load_collection) with their id and weight 1. + # Propagate these labels along the depends-on links, but split/sum the weight according + # to the number of children/parents. + # At the end: the articulation points are the nodes where all flows have weight 1. + + # Mapping: node_id -> start_node_id -> flow_weight + flow_weights: Dict[NodeId, Dict[NodeId, fractions.Fraction]] = {} + + # Initialize at the pure input nodes (nodes with no upstream dependencies) + for node_id, node in self.iter_nodes(): + if not node.depends_on: + flow_weights[node_id] = {node_id: fractions.Fraction(1, 1)} + + # Propagate flow weights using recursion + caching + def get_flow_weights(node_id: NodeId) -> Dict[NodeId, fractions.Fraction]: + nonlocal flow_weights + if node_id not in flow_weights: + flow_weights[node_id] = {} + # Calculate from upstream nodes + for upstream in self.node(node_id).depends_on: + for start_node_id, weight in get_flow_weights(upstream).items(): + flow_weights[node_id].setdefault(start_node_id, fractions.Fraction(0, 1)) + flow_weights[node_id][start_node_id] += weight / len(self.node(upstream).flows_to) + return flow_weights[node_id] + + for node_id, node in self.iter_nodes(): + get_flow_weights(node_id) + + # Select articulation points: nodes where all flows have weight 1 + return set(node_id for node_id, flows in flow_weights.items() if all(w == 1 for w in flows.values())) + + def split_at(self, split_node_id: NodeId) -> Tuple[_GraphViewer, _GraphViewer]: + """ + Split graph at given node id (must be articulation point), + creating two new graph viewers, containing original nodes and adaptation of the split node. 
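A small worked example of the flow-weight idea behind `find_articulation_points` above, as a sketch that assumes the internal `_GraphViewer` helpers from this module: in a diamond graph, the weight 1 leaving the source is split 1/2 + 1/2 over the two branches and only recombines to 1 at the join, so the branch nodes are not valid split locations:

```python
from openeo_aggregator.partitionedjobs.crossbackend import _GraphViewer

#     a
#    / \
#   b   c
#    \ /
#     d
#     |
#     e
graph = _GraphViewer.from_edges([("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("d", "e")])

# "b" and "c" each carry flow weight 1/2, so they are not cut points;
# "a", "d" and "e" carry weight 1 and are articulation points.
assert graph.find_articulation_points() == {"a", "d", "e"}
```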
+ + :return: two _GraphViewer objects: the upstream subgraph and the downstream subgraph + """ + split_node = self.node(split_node_id) + + # Walk the graph, upstream from the split node + def next_nodes(node_id: NodeId) -> Iterable[NodeId]: + node = self.node(node_id) + if node_id == split_node_id: + return node.depends_on + else: + return node.depends_on.union(node.flows_to) + + up_node_ids = set(self._walk(seeds=[split_node_id], next_nodes=next_nodes)) + + if split_node.flows_to.intersection(up_node_ids): + raise GraphSplitException(f"Graph can not be split at {split_node_id!r}: not an articulation point.") + + up_graph = {n: self.node(n) for n in up_node_ids} + # Replacement of original split node: no `flows_to` links + up_graph[split_node_id] = _GVNode( + depends_on=split_node.depends_on, + backend_candidates=split_node.backend_candidates, + ) + up = _GraphViewer(node_map=up_graph) + + down_graph = {n: node for n, node in self.iter_nodes() if n not in up_node_ids} + # Replacement of original split node: no `depends_on` links + # and perhaps more importantly: do not copy over the original `backend_candidates`` + down_graph[split_node_id] = _GVNode( + flows_to=split_node.flows_to, + backend_candidates=None, + ) + down = _GraphViewer(node_map=down_graph) + + return up, down + + def produce_split_locations( + self, + limit: int = 20, + allow_split: Callable[[NodeId], bool] = lambda n: True, + ) -> Iterator[List[NodeId]]: + """ + Produce disjoint subgraphs that can be processed independently. + + :param limit: maximum number of split locations to produce + :param allow_split: predicate to determine if a node can be split on (e.g. to deny splitting on certain nodes) + + :return: iterator of node listings. + Each node listing encodes a graph split (nodes ids where to split). + A node listing is ordered with the following in mind: + - the first node id does a first split in a downstream and upstream part. + The upstream part can be handled by a single backend. + The downstream part is not necessarily covered by a single backend, + in which case one or more additional splits will be necessary. + - the second node id does a second split of the downstream part of + the previous split. + - etc + """ + # TODO: allow_split: possible to make this backend-dependent in some way? + + # Find nodes that have empty set of backend_candidates + forsaken_nodes = self.find_forsaken_nodes() + + if forsaken_nodes: + # Sort forsaken nodes (based on forsaken parent count), to start higher up the graph + forsaken_nodes = sorted( + forsaken_nodes, key=lambda n: (sum(p in forsaken_nodes for p in self.node(n).depends_on), n) + ) + _log.debug(f"_GraphViewer.produce_split_locations: {forsaken_nodes=}") + + # Collect nodes where we could split the graph in disjoint subgraphs + articulation_points: Set[NodeId] = set(self.find_articulation_points()) + _log.debug(f"_GraphViewer.produce_split_locations: {articulation_points=}") + + # Walk upstream from forsaken nodes to find articulation points, where we can cut + split_options = [ + n + for n in self.walk_upstream_nodes(seeds=forsaken_nodes, include_seeds=False) + if n in articulation_points and allow_split(n) + ] + _log.debug(f"_GraphViewer.produce_split_locations: {split_options=}") + if not split_options: + raise GraphSplitException("No split options found.") + # TODO: Do we really need a limit? Or is there a practical scalability risk to list all possibilities? 
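To make the procedure above concrete, a minimal sketch (again assuming this module is importable) of `produce_split_locations` on a two-collection graph whose merge node has no common backend:

```python
from openeo_aggregator.partitionedjobs.crossbackend import _GraphViewer

#   lc1    lc2
#     \    /
#     merge
graph = _GraphViewer.from_edges(
    edges=[("lc1", "merge"), ("lc2", "merge")],
    supporting_backends_mapper=lambda node_id, node: {"lc1": ["b1"], "lc2": ["b2"]}.get(node_id),
)

# "merge" is forsaken (the intersection of b1 and b2 is empty), so the graph must be
# cut at one of the load_collection nodes; both single-node cuts come out as options.
assert sorted(graph.produce_split_locations()) == [["lc1"], ["lc2"]]
```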
+ assert limit > 0 + for split_node_id in split_options[:limit]: + _log.debug(f"_GraphViewer.produce_split_locations: splitting at {split_node_id=}") + up, down = self.split_at(split_node_id) + # The upstream part should now be handled by a single backend + assert not up.find_forsaken_nodes() + # Recursively split downstream part if necessary + if down.find_forsaken_nodes(): + down_splits = list(down.produce_split_locations(limit=max(limit - 1, 1), allow_split=allow_split)) + else: + down_splits = [[]] + + for down_split in down_splits: + yield [split_node_id] + down_split + + else: + # All nodes can be handled as is, no need to split + yield [] + + def split_at_multiple(self, split_nodes: List[NodeId]) -> Dict[Union[NodeId, None], _GraphViewer]: + """ + Split the graph viewer at multiple nodes in the order as provided. + Each split produces an upstream and downstream graph, + the downstream graph is used for the next split, + so the split nodes should be ordered as such. + + Returns dictionary with: + - key: split node_ids or None for the final downstream graph + - value: corresponding sub graph viewers as values. + """ + result = {} + graph_to_split = self + for split_node_id in split_nodes: + up, down = graph_to_split.split_at(split_node_id=split_node_id) + result[split_node_id] = up + graph_to_split = down + result[None] = graph_to_split + return result + + +class DeepGraphSplitter(ProcessGraphSplitterInterface): + """ + More advanced graph splitting (compared to just splitting off `load_collection` nodes) + + :param split_deny_list: list of process ids that should not be split on + """ + + def __init__( + self, + supporting_backends: SupportingBackendsMapper, + primary_backend: Optional[BackendId] = None, + split_deny_list: Iterable[ProcessId] = (), + ): + self._supporting_backends_mapper = supporting_backends + self._primary_backend = primary_backend + # TODO also support other deny mechanisms, e.g. callable instead of a deny list? 
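And an end-to-end sketch of the `DeepGraphSplitter` defined here, assuming the package is importable; the collection-to-backend availability is made up for illustration. With `primary_backend="b1"`, the `lc2` branch is split off to `"b2"` and the remaining nodes stay on the primary backend:

```python
from openeo_aggregator.partitionedjobs.crossbackend import DeepGraphSplitter

process_graph = {
    "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
    "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}},
    "merge": {
        "process_id": "merge_cubes",
        "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
        "result": True,
    },
}


# Made-up availability: "S2" only on backend "b1", "T22" only on backend "b2".
def supporting_backends(node_id: str, node: dict):
    if node["process_id"] == "load_collection":
        return {"S2": ["b1"], "T22": ["b2"]}[node["arguments"]["id"]]


splitter = DeepGraphSplitter(supporting_backends=supporting_backends, primary_backend="b1")
result = splitter.split(process_graph)

assert result.primary_backend_id == "b1"
assert result.primary_node_ids == {"lc1", "lc2", "merge"}
assert [(g.split_node, g.backend_id) for g in result.secondary_graphs] == [("lc2", "b2")]
```

Note how the first candidate split (cutting at `lc1`) would leave `"b2"` as primary backend and is rejected against the `primary_backend="b1"` constraint, so the splitter falls through to cutting at `lc2`.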
+ self._split_deny_list = set(split_deny_list) + + def _pick_backend(self, backend_candidates: Union[frozenset[BackendId], None]) -> BackendId: + if backend_candidates is None: + if self._primary_backend: + return self._primary_backend + else: + raise GraphSplitException("DeepGraphSplitter._pick_backend: No backend candidates.") + else: + # TODO: better backend selection mechanism + return sorted(backend_candidates)[0] + + def split(self, process_graph: FlatPG) -> _PGSplitResult: + graph = _GraphViewer.from_flat_graph( + flat_graph=process_graph, supporting_backends=self._supporting_backends_mapper + ) + + def allow_split(node_id: NodeId) -> bool: + process_id = deep_get(process_graph, node_id, "process_id", default=None) + return process_id not in self._split_deny_list + + for split_nodes in graph.produce_split_locations(allow_split=allow_split): + _log.debug(f"DeepGraphSplitter.split: evaluating split nodes: {split_nodes=}") + + split_views = graph.split_at_multiple(split_nodes=split_nodes) + + # Extract nodes and backend ids for each subgraph + subgraph_node_ids = {k: set(n for n, _ in v.iter_nodes()) for k, v in split_views.items()} + subgraph_backend_ids = { + k: self._pick_backend(backend_candidates=v.get_backend_candidates_for_node_set(subgraph_node_ids[k])) + for k, v in split_views.items() + } + _log.debug(f"DeepGraphSplitter.split: {subgraph_node_ids=} {subgraph_backend_ids=}") + + # Handle primary graph + split_views.pop(None) + primary_node_ids = subgraph_node_ids.pop(None) + primary_backend_id = subgraph_backend_ids.pop(None) + + # Handle secondary graphs + secondary_graphs = [ + _PGSplitSubGraph(split_node=k, node_ids=subgraph_node_ids[k], backend_id=subgraph_backend_ids[k]) + for k in split_views.keys() + ] + + if self._primary_backend is None or primary_backend_id == self._primary_backend: + _log.debug(f"DeepGraphSplitter.split: current split matches constraints") + return _PGSplitResult( + primary_node_ids=primary_node_ids, + primary_backend_id=primary_backend_id, + secondary_graphs=secondary_graphs, + ) + + raise GraphSplitException("DeepGraphSplitter.split: No matching split found.") diff --git a/src/openeo_aggregator/partitionedjobs/tracking.py b/src/openeo_aggregator/partitionedjobs/tracking.py index a34f36e..26ce77a 100644 --- a/src/openeo_aggregator/partitionedjobs/tracking.py +++ b/src/openeo_aggregator/partitionedjobs/tracking.py @@ -24,7 +24,7 @@ SubJob, ) from openeo_aggregator.partitionedjobs.crossbackend import ( - CrossBackendSplitter, + CrossBackendJobSplitter, SubGraphId, ) from openeo_aggregator.partitionedjobs.splitting import TileGridSplitter @@ -71,7 +71,7 @@ def create_crossbackend_pjob( process: PGWithMetadata, metadata: dict, job_options: Optional[dict] = None, - splitter: CrossBackendSplitter, + splitter: CrossBackendJobSplitter, ) -> str: """ crossbackend partitioned job creation is different from original partitioned diff --git a/src/openeo_aggregator/testing.py b/src/openeo_aggregator/testing.py index 6cf60d8..9f044b8 100644 --- a/src/openeo_aggregator/testing.py +++ b/src/openeo_aggregator/testing.py @@ -66,13 +66,13 @@ def exists(self, path): def get(self, path): self._assert_open() if path not in self.data: - raise kazoo.exceptions.NoNodeError() + raise kazoo.exceptions.NoNodeError(path) return self.data[path] def get_children(self, path): self._assert_open() if path not in self.data: - raise kazoo.exceptions.NoNodeError() + raise kazoo.exceptions.NoNodeError(path) parent = path.split("/") return [p.split("/")[-1] for p in self.data if 
p.split("/")[:-1] == parent] @@ -103,6 +103,8 @@ def approx_now(abs=10): class ApproxStr: """Pytest helper in style of `pytest.approx`, but for string checking, based on prefix, body and or suffix""" + # TODO: port to dirty_equals + def __init__( self, prefix: Optional[str] = None, diff --git a/tests/partitionedjobs/test_api.py b/tests/partitionedjobs/test_api.py index aaf1946..956cae6 100644 --- a/tests/partitionedjobs/test_api.py +++ b/tests/partitionedjobs/test_api.py @@ -52,7 +52,7 @@ def __init__(self, date: str): def dummy1(backend1, requests_mock) -> DummyBackend: # TODO: rename this fixture to dummy_backed1 for clarity dummy = DummyBackend(requests_mock=requests_mock, backend_url=backend1, job_id_template="1-jb-{i}") - dummy.setup_basic_requests_mocks() + dummy.setup_basic_requests_mocks(collections=["S1", "S2"]) dummy.register_user(bearer_token=TEST_USER_BEARER_TOKEN, user_id=TEST_USER) return dummy @@ -61,7 +61,7 @@ def dummy1(backend1, requests_mock) -> DummyBackend: def dummy2(backend2, requests_mock) -> DummyBackend: # TODO: rename this fixture to dummy_backed2 for clarity dummy = DummyBackend(requests_mock=requests_mock, backend_url=backend2, job_id_template="2-jb-{i}") - dummy.setup_basic_requests_mocks(collections=["S22"]) + dummy.setup_basic_requests_mocks(collections=["T11", "T22"]) dummy.register_user(bearer_token=TEST_USER_BEARER_TOKEN, user_id=TEST_USER) return dummy @@ -685,17 +685,28 @@ def _partitioned_job_tracking(self, zk_client): yield @now.mock - def test_create_job_simple(self, flask_app, api100, zk_db, dummy1): + @pytest.mark.parametrize( + "split_strategy", + [ + "crossbackend", + {"crossbackend": {"method": "simple"}}, + {"crossbackend": {"method": "deep"}}, + ], + ) + def test_create_job_simple(self, flask_app, api100, zk_db, dummy1, split_strategy): """Handling of single "load_collection" process graph""" api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN) - pg = {"lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}, "result": True}} + pg = { + # lc1 (that's it, that's the graph) + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}, "result": True} + } res = api100.post( "/jobs", json={ "process": {"process_graph": pg}, - "job_options": {"split_strategy": "crossbackend"}, + "job_options": {"split_strategy": split_strategy}, }, ).assert_status_code(201) @@ -719,7 +730,7 @@ def test_create_job_simple(self, flask_app, api100, zk_db, dummy1): "created": self.now.epoch, "process": {"process_graph": pg}, "metadata": {"log_level": "info"}, - "job_options": {"split_strategy": "crossbackend"}, + "job_options": {"split_strategy": split_strategy}, "result_jobs": ["main"], } @@ -753,12 +764,23 @@ def test_create_job_simple(self, flask_app, api100, zk_db, dummy1): assert pg == {"lc1": {"arguments": {"id": "S2"}, "process_id": "load_collection", "result": True}} @now.mock - def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock): + @pytest.mark.parametrize( + "split_strategy", + [ + "crossbackend", + {"crossbackend": {"method": "simple"}}, + {"crossbackend": {"method": "deep", "primary_backend": "b1"}}, + ], + ) + def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, dummy2, requests_mock, split_strategy): api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN) pg = { + # lc1 lc2 + # \ / + # merge "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, - "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc2": {"process_id": "load_collection", "arguments": 
{"id": "T22"}}, "merge": { "process_id": "merge_cubes", "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, @@ -767,13 +789,16 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock) } requests_mock.get( - "https://b1.test/v1/jobs/1-jb-0/results?partial=true", - json={"links": [{"rel": "canonical", "href": "https://data.b1.test/123abc"}]}, + "https://b2.test/v1/jobs/2-jb-0/results?partial=true", + json={"links": [{"rel": "canonical", "href": "https://data.b2.test/123abc"}]}, ) res = api100.post( "/jobs", - json={"process": {"process_graph": pg}, "job_options": {"split_strategy": "crossbackend"}}, + json={ + "process": {"process_graph": pg}, + "job_options": {"split_strategy": split_strategy}, + }, ).assert_status_code(201) pjob_id = "pj-20220119-123456" @@ -796,7 +821,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock) "created": self.now.epoch, "process": {"process_graph": pg}, "metadata": {"log_level": "info"}, - "job_options": {"split_strategy": "crossbackend"}, + "job_options": {"split_strategy": split_strategy}, "result_jobs": ["main"], } @@ -810,17 +835,17 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock) # Inspect stored subjob metadata subjobs = zk_db.list_subjobs(user_id=TEST_USER, pjob_id=pjob_id) assert subjobs == { - "b1:lc2": { - "backend_id": "b1", + "b2:lc2": { + "backend_id": "b2", "process_graph": { - "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, - "sr1": { + "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}}, + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True, }, }, - "title": "Partitioned job pjob_id='pj-20220119-123456' sjob_id='b1:lc2'", + "title": "Partitioned job pjob_id='pj-20220119-123456' sjob_id='b2:lc2'", }, "main": { "backend_id": "b1", @@ -828,7 +853,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock) "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, "lc2": { "process_id": "load_stac", - "arguments": {"url": "https://data.b1.test/123abc"}, + "arguments": {"url": "https://data.b2.test/123abc"}, }, "merge": { "process_id": "merge_cubes", @@ -841,7 +866,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock) } sjob_id = "main" - expected_job_id = "1-jb-1" + expected_job_id = "1-jb-0" assert zk_db.get_sjob_status(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == { "status": "created", "timestamp": self.now.epoch, @@ -853,7 +878,7 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock) "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, "lc2": { "process_id": "load_stac", - "arguments": {"url": "https://data.b1.test/123abc"}, + "arguments": {"url": "https://data.b2.test/123abc"}, }, "merge": { "process_id": "merge_cubes", @@ -862,18 +887,18 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock) }, } - sjob_id = "b1:lc2" - expected_job_id = "1-jb-0" + sjob_id = "b2:lc2" + expected_job_id = "2-jb-0" assert zk_db.get_sjob_status(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == { "status": "created", "timestamp": self.now.epoch, "message": None, } assert zk_db.get_backend_job_id(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == expected_job_id - assert dummy1.get_job_status(TEST_USER, expected_job_id) == "created" - assert 
dummy1.get_job_data(TEST_USER, expected_job_id).create["process"]["process_graph"] == { - "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, - "sr1": { + assert dummy2.get_job_status(TEST_USER, expected_job_id) == "created" + assert dummy2.get_job_data(TEST_USER, expected_job_id).create["process"]["process_graph"] == { + "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}}, + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True, @@ -881,13 +906,24 @@ def test_create_job_basic(self, flask_app, api100, zk_db, dummy1, requests_mock) } @now.mock - def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_mock): + @pytest.mark.parametrize( + "split_strategy", + [ + "crossbackend", + {"crossbackend": {"method": "simple"}}, + {"crossbackend": {"method": "deep", "primary_backend": "b1"}}, + ], + ) + def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, dummy2, requests_mock, split_strategy): """Run the jobs and get results""" api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN) pg = { + # lc1 lc2 + # \ / + # merge "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, - "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}}, "merge": { "process_id": "merge_cubes", "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, @@ -896,15 +932,15 @@ def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_ } requests_mock.get( - "https://b1.test/v1/jobs/1-jb-0/results?partial=true", - json={"links": [{"rel": "canonical", "href": "https://data.b1.test/123abc"}]}, + "https://b2.test/v1/jobs/2-jb-0/results?partial=true", + json={"links": [{"rel": "canonical", "href": "https://data.b2.test/123abc"}]}, ) res = api100.post( "/jobs", json={ "process": {"process_graph": pg}, - "job_options": {"split_strategy": "crossbackend"}, + "job_options": {"split_strategy": split_strategy}, }, ).assert_status_code(201) @@ -923,21 +959,21 @@ def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_ # start job api100.post(f"/jobs/{expected_job_id}/results").assert_status_code(202) - dummy1.set_job_status(TEST_USER, "1-jb-0", status="running") - dummy1.set_job_status(TEST_USER, "1-jb-1", status="queued") + dummy2.set_job_status(TEST_USER, "2-jb-0", status="running") + dummy1.set_job_status(TEST_USER, "1-jb-0", status="queued") res = api100.get(f"/jobs/{expected_job_id}").assert_status_code(200) assert res.json == DictSubSet({"id": expected_job_id, "status": "running", "progress": 0}) # First job is ready - dummy1.set_job_status(TEST_USER, "1-jb-0", status="finished") - dummy1.setup_assets(job_id=f"1-jb-0", assets=["1-jb-0-result.tif"]) - dummy1.set_job_status(TEST_USER, "1-jb-1", status="running") + dummy2.set_job_status(TEST_USER, "2-jb-0", status="finished") + dummy2.setup_assets(job_id=f"2-jb-0", assets=["2-jb-0-result.tif"]) + dummy1.set_job_status(TEST_USER, "1-jb-0", status="running") res = api100.get(f"/jobs/{expected_job_id}").assert_status_code(200) assert res.json == DictSubSet({"id": expected_job_id, "status": "running", "progress": 50}) # Main job is ready too - dummy1.set_job_status(TEST_USER, "1-jb-1", status="finished") - dummy1.setup_assets(job_id=f"1-jb-1", assets=["1-jb-1-result.tif"]) + dummy1.set_job_status(TEST_USER, "1-jb-0", status="finished") + 
dummy1.setup_assets(job_id=f"1-jb-0", assets=["1-jb-0-result.tif"]) res = api100.get(f"/jobs/{expected_job_id}").assert_status_code(200) assert res.json == DictSubSet({"id": expected_job_id, "status": "finished", "progress": 100}) @@ -947,10 +983,10 @@ def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_ { "id": expected_job_id, "assets": { - "main-1-jb-1-result.tif": { - "href": "https://b1.test/v1/jobs/1-jb-1/results/1-jb-1-result.tif", + "main-1-jb-0-result.tif": { + "href": "https://b1.test/v1/jobs/1-jb-0/results/1-jb-0-result.tif", "roles": ["data"], - "title": "main-1-jb-1-result.tif", + "title": "main-1-jb-0-result.tif", "type": "application/octet-stream", }, }, @@ -958,14 +994,25 @@ def test_start_and_job_results(self, flask_app, api100, zk_db, dummy1, requests_ ) @now.mock - def test_failing_create(self, flask_app, api100, zk_db, dummy1): + @pytest.mark.parametrize( + "split_strategy", + [ + "crossbackend", + {"crossbackend": {"method": "simple"}}, + {"crossbackend": {"method": "deep", "primary_backend": "b1"}}, + ], + ) + def test_failing_create(self, flask_app, api100, zk_db, dummy1, dummy2, split_strategy): """Run what happens when creation of sub batch job fails on upstream backend""" api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN) - dummy1.fail_create_job = True + dummy2.fail_create_job = True pg = { + # lc1 lc2 + # \ / + # merge "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, - "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}}, "merge": { "process_id": "merge_cubes", "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, @@ -977,7 +1024,7 @@ def test_failing_create(self, flask_app, api100, zk_db, dummy1): "/jobs", json={ "process": {"process_graph": pg}, - "job_options": {"split_strategy": "crossbackend"}, + "job_options": {"split_strategy": split_strategy}, }, ).assert_status_code(201) @@ -993,3 +1040,138 @@ def test_failing_create(self, flask_app, api100, zk_db, dummy1): "created": self.now.rfc3339, "progress": 0, } + + @now.mock + def test_create_job_deep_basic(self, flask_app, api100, zk_db, dummy1, dummy2, requests_mock): + api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN) + + pg = { + # lc1 lc2 + # | | + # bands1 temporal2 + # \ / + # merge + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "T22"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}}}, + "temporal2": {"process_id": "filter_temporal", "arguments": {"data": {"from_node": "lc2"}}}, + "merge": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}}, + "result": True, + }, + } + + requests_mock.get( + "https://b2.test/v1/jobs/2-jb-0/results?partial=true", + json={"links": [{"rel": "canonical", "href": "https://data.b2.test/123abc"}]}, + ) + + split_strategy = {"crossbackend": {"method": "deep", "primary_backend": "b1"}} + res = api100.post( + "/jobs", + json={ + "process": {"process_graph": pg}, + "job_options": {"split_strategy": split_strategy}, + }, + ).assert_status_code(201) + + pjob_id = "pj-20220119-123456" + expected_job_id = f"agg-{pjob_id}" + assert res.headers["Location"] == f"http://oeoa.test/openeo/1.0.0/jobs/{expected_job_id}" + assert res.headers["OpenEO-Identifier"] == expected_job_id + + res = 
api100.get(f"/jobs/{expected_job_id}").assert_status_code(200) + assert res.json == { + "id": expected_job_id, + "process": {"process_graph": pg}, + "status": "created", + "created": self.now.rfc3339, + "progress": 0, + } + + # Inspect stored parent job metadata + assert zk_db.get_pjob_metadata(user_id=TEST_USER, pjob_id=pjob_id) == { + "user_id": TEST_USER, + "created": self.now.epoch, + "process": {"process_graph": pg}, + "metadata": {"log_level": "info"}, + "job_options": {"split_strategy": split_strategy}, + "result_jobs": ["main"], + } + + assert zk_db.get_pjob_status(user_id=TEST_USER, pjob_id=pjob_id) == { + "status": "created", + "message": approx_str_contains("{'created': 2}"), + "timestamp": pytest.approx(self.now.epoch, abs=5), + "progress": 0, + } + + # Inspect stored subjob metadata + subjobs = zk_db.list_subjobs(user_id=TEST_USER, pjob_id=pjob_id) + assert subjobs == { + "b2:temporal2": { + "backend_id": "b2", + "process_graph": { + "lc2": {"arguments": {"id": "T22"}, "process_id": "load_collection"}, + "temporal2": {"arguments": {"data": {"from_node": "lc2"}}, "process_id": "filter_temporal"}, + "_agg_crossbackend_save_result": { + "arguments": {"data": {"from_node": "temporal2"}, "format": "GTiff"}, + "process_id": "save_result", + "result": True, + }, + }, + "title": "Partitioned job pjob_id='pj-20220119-123456' sjob_id='b2:temporal2'", + }, + "main": { + "backend_id": "b1", + "process_graph": { + "lc1": {"arguments": {"id": "S2"}, "process_id": "load_collection"}, + "bands1": {"arguments": {"data": {"from_node": "lc1"}}, "process_id": "filter_bands"}, + "temporal2": {"arguments": {"url": "https://data.b2.test/123abc"}, "process_id": "load_stac"}, + "merge": { + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}}, + "process_id": "merge_cubes", + "result": True, + }, + }, + "title": "Partitioned job pjob_id='pj-20220119-123456' " "sjob_id='main'", + }, + } + sjob_id = "main" + expected_job_id = "1-jb-0" + assert zk_db.get_sjob_status(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == { + "status": "created", + "timestamp": self.now.epoch, + "message": None, + } + assert zk_db.get_backend_job_id(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == expected_job_id + assert dummy1.get_job_status(TEST_USER, expected_job_id) == "created" + assert dummy1.get_job_data(TEST_USER, expected_job_id).create["process"]["process_graph"] == { + "lc1": {"arguments": {"id": "S2"}, "process_id": "load_collection"}, + "bands1": {"arguments": {"data": {"from_node": "lc1"}}, "process_id": "filter_bands"}, + "temporal2": {"arguments": {"url": "https://data.b2.test/123abc"}, "process_id": "load_stac"}, + "merge": { + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}}, + "process_id": "merge_cubes", + "result": True, + }, + } + sjob_id = "b2:temporal2" + expected_job_id = "2-jb-0" + assert zk_db.get_sjob_status(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == { + "status": "created", + "timestamp": self.now.epoch, + "message": None, + } + assert zk_db.get_backend_job_id(user_id=TEST_USER, pjob_id=pjob_id, sjob_id=sjob_id) == expected_job_id + assert dummy2.get_job_status(TEST_USER, expected_job_id) == "created" + assert dummy2.get_job_data(TEST_USER, expected_job_id).create["process"]["process_graph"] == { + "lc2": {"arguments": {"id": "T22"}, "process_id": "load_collection"}, + "temporal2": {"arguments": {"data": {"from_node": "lc2"}}, "process_id": "filter_temporal"}, + "_agg_crossbackend_save_result": { + 
"arguments": {"data": {"from_node": "temporal2"}, "format": "GTiff"}, + "process_id": "save_result", + "result": True, + }, + } diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index b4b6239..adf0e65 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -13,8 +13,16 @@ from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob from openeo_aggregator.partitionedjobs.crossbackend import ( - CrossBackendSplitter, + CrossBackendJobSplitter, + DeepGraphSplitter, + GraphSplitException, + LoadCollectionGraphSplitter, SubGraphId, + SupportingBackendsMapper, + _GraphViewer, + _GVNode, + _PGSplitResult, + _PGSplitSubGraph, run_partitioned_job, ) @@ -22,7 +30,9 @@ class TestCrossBackendSplitter: def test_split_simple(self): process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo") + splitter = CrossBackendJobSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo") + ) res = splitter.split({"process_graph": process_graph}) assert res.subjobs == {"main": SubJob(process_graph, backend_id=None)} @@ -30,7 +40,9 @@ def test_split_simple(self): def test_split_streaming_simple(self): process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo") + splitter = CrossBackendJobSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo") + ) res = splitter.split_streaming(process_graph) assert isinstance(res, types.GeneratorType) assert list(res) == [("main", SubJob(process_graph, backend_id=None), [])] @@ -46,13 +58,15 @@ def test_split_basic(self): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, }, } - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + splitter = CrossBackendJobSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + ) res = splitter.split({"process_graph": process_graph}) assert res.subjobs == { @@ -73,7 +87,7 @@ def test_split_basic(self): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, @@ -87,7 +101,7 @@ def test_split_basic(self): "process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True, @@ -109,13 +123,15 @@ def test_split_streaming_basic(self): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, }, } - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + splitter = CrossBackendJobSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + ) result = splitter.split_streaming(process_graph) assert isinstance(result, types.GeneratorType) @@ -128,7 +144,7 @@ def test_split_streaming_basic(self): "process_id": 
"load_collection", "arguments": {"id": "B2_FAPAR"}, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True, @@ -148,7 +164,7 @@ def test_split_streaming_basic(self): "process_id": "merge_cubes", "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, @@ -175,7 +191,9 @@ def test_split_streaming_get_replacement(self): "result": True, }, } - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + splitter = CrossBackendJobSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + ) batch_jobs = {} @@ -200,7 +218,7 @@ def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict: SubJob( process_graph={ "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}}, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True, @@ -215,7 +233,7 @@ def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict: SubJob( process_graph={ "lc3": {"process_id": "load_collection", "arguments": {"id": "B3_SCL"}}, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc3"}, "format": "GTiff"}, "result": True, @@ -365,13 +383,15 @@ def test_basic(self, aggregator: _FakeAggregator): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, }, } - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + splitter = CrossBackendJobSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + ) pjob: PartitionedJob = splitter.split({"process_graph": process_graph}) connection = openeo.Connection(aggregator.url) @@ -400,7 +420,7 @@ def test_basic(self, aggregator: _FakeAggregator): "cube2": {"from_node": "lc2"}, }, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"}, "result": True, @@ -411,9 +431,792 @@ def test_basic(self, aggregator: _FakeAggregator): "process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}, }, - "sr1": { + "_agg_crossbackend_save_result": { "process_id": "save_result", "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"}, "result": True, }, } + + +class TestGVNode: + def test_defaults(self): + node = _GVNode() + assert isinstance(node.depends_on, frozenset) + assert node.depends_on == frozenset() + assert isinstance(node.flows_to, frozenset) + assert node.flows_to == frozenset() + assert node.backend_candidates is None + + def test_basic(self): + node = _GVNode(depends_on=["a", "b"], flows_to=["c", "d"], backend_candidates=["X"]) + assert isinstance(node.depends_on, frozenset) + assert node.depends_on == frozenset(["a", "b"]) + assert isinstance(node.flows_to, frozenset) + assert node.flows_to == frozenset(["c", "d"]) + assert isinstance(node.backend_candidates, frozenset) + assert node.backend_candidates == frozenset(["X"]) + + def test_single_strings(self): + node = _GVNode(depends_on="apple", 
flows_to="banana", backend_candidates="coconut") + assert isinstance(node.depends_on, frozenset) + assert node.depends_on == frozenset(["apple"]) + assert isinstance(node.flows_to, frozenset) + assert node.flows_to == frozenset(["banana"]) + assert isinstance(node.backend_candidates, frozenset) + assert node.backend_candidates == frozenset(["coconut"]) + + def test_eq(self): + assert _GVNode() == _GVNode() + assert _GVNode( + depends_on=["a", "b"], + flows_to=["c", "d"], + backend_candidates="X", + ) == _GVNode( + depends_on=("b", "a"), + flows_to={"d", "c"}, + backend_candidates=["X"], + ) + + def test_repr(self): + assert repr(_GVNode()) == "[_GVNode]" + assert repr(_GVNode(depends_on="a")) == "[_GVNode b]" + assert repr(_GVNode(depends_on=["a", "b"], flows_to=["foo", "bar"])) == "[_GVNode bar,foo]" + assert repr(_GVNode(depends_on="a", flows_to="b", backend_candidates=["x", "yy"])) == "[_GVNode b @x,yy]" + + +def supporting_backends_from_node_id_dict(data: dict) -> SupportingBackendsMapper: + return lambda node_id, node: data.get(node_id) + + +class TestGraphViewer: + def test_empty(self): + graph = _GraphViewer(node_map={}) + assert list(graph.iter_nodes()) == [] + + @pytest.mark.parametrize( + ["node_map", "expected_error"], + [ + ({"a": _GVNode(flows_to="b")}, r"Inconsistent.*unknown=\{'b'\}"), + ({"b": _GVNode(depends_on="a")}, r"Inconsistent.*unknown=\{'a'\}"), + ({"a": _GVNode(flows_to="b"), "b": _GVNode()}, r"Inconsistent.*bad_links=\{\('a', 'b'\)\}"), + ({"b": _GVNode(depends_on="a"), "a": _GVNode()}, r"Inconsistent.*bad_links=\{\('a', 'b'\)\}"), + ], + ) + def test_check_consistency(self, node_map, expected_error): + with pytest.raises(GraphSplitException, match=expected_error): + _ = _GraphViewer(node_map=node_map) + + def test_immutability(self): + node_map = {"a": _GVNode(flows_to="b"), "b": _GVNode(depends_on="a")} + graph = _GraphViewer(node_map=node_map) + assert sorted(graph.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))] + + # Adding a node to the original map should not affect the graph + node_map["c"] = _GVNode() + assert sorted(graph.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))] + + # Trying to mess with internals shouldn't work either + with pytest.raises(Exception, match="does not support item assignment"): + graph._graph["c"] = _GVNode() + + assert sorted(graph.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))] + + def test_from_flat_graph_basic(self): + flat = { + "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}}, + "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True}, + } + graph = _GraphViewer.from_flat_graph( + flat, supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"]}) + ) + assert sorted(graph.iter_nodes()) == [ + ("lc1", _GVNode(flows_to=["ndvi1"], backend_candidates="b1")), + ("ndvi1", _GVNode(depends_on=["lc1"])), + ] + + # TODO: test from_flat_graph with more complex graphs + + def test_from_edges(self): + graph = _GraphViewer.from_edges([("a", "c"), ("b", "d"), ("c", "e"), ("d", "e"), ("e", "f")]) + assert sorted(graph.iter_nodes()) == [ + ("a", _GVNode(flows_to=["c"])), + ("b", _GVNode(flows_to=["d"])), + ("c", _GVNode(depends_on=["a"], flows_to=["e"])), + ("d", _GVNode(depends_on=["b"], flows_to=["e"])), + ("e", _GVNode(depends_on=["c", "d"], flows_to=["f"])), + ("f", _GVNode(depends_on=["e"])), + ] + + @pytest.mark.parametrize( + ["seed", "include_seeds", "expected"], + [ + 
(["a"], True, ["a"]), + (["a"], False, []), + (["c"], True, ["c", "a"]), + (["c"], False, ["a"]), + (["a", "c"], True, ["a", "c"]), + (["a", "c"], False, []), + (["c", "a"], True, ["a", "c"]), + (["c", "a"], False, []), + (["e"], True, ["e", "c", "d", "a", "b"]), + (["e"], False, ["c", "d", "a", "b"]), + (["e", "d"], True, ["d", "e", "b", "c", "a"]), + (["e", "d"], False, ["c", "b", "a"]), + (["d", "e"], True, ["d", "e", "b", "c", "a"]), + (["d", "e"], False, ["b", "c", "a"]), + (["f", "c"], True, ["c", "f", "a", "e", "d", "b"]), + (["f", "c"], False, ["e", "a", "d", "b"]), + ], + ) + def test_walk_upstream_nodes(self, seed, include_seeds, expected): + graph = _GraphViewer.from_edges( + # a b + # | | + # c d + # \ / + # e + # | + # f + [("a", "c"), ("b", "d"), ("c", "e"), ("d", "e"), ("e", "f")] + ) + assert list(graph.walk_upstream_nodes(seed, include_seeds)) == expected + + def test_get_backend_candidates_basic(self): + graph = _GraphViewer.from_edges( + # a + # | + # b c + # \ / + # d + [("a", "b"), ("b", "d"), ("c", "d")], + supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": ["b1"], "c": ["b2"]}), + ) + assert graph.get_backend_candidates_for_node("a") == {"b1"} + assert graph.get_backend_candidates_for_node("b") == {"b1"} + assert graph.get_backend_candidates_for_node("c") == {"b2"} + assert graph.get_backend_candidates_for_node("d") == set() + + assert graph.get_backend_candidates_for_node_set(["a"]) == {"b1"} + assert graph.get_backend_candidates_for_node_set(["b"]) == {"b1"} + assert graph.get_backend_candidates_for_node_set(["c"]) == {"b2"} + assert graph.get_backend_candidates_for_node_set(["d"]) == set() + assert graph.get_backend_candidates_for_node_set(["a", "b"]) == {"b1"} + assert graph.get_backend_candidates_for_node_set(["a", "b", "c"]) == set() + assert graph.get_backend_candidates_for_node_set(["a", "b", "d"]) == set() + + def test_get_backend_candidates_none(self): + graph = _GraphViewer.from_edges( + # a + # | + # b c + # \ / + # d + [("a", "b"), ("b", "d"), ("c", "d")], + ) + assert graph.get_backend_candidates_for_node("a") is None + assert graph.get_backend_candidates_for_node("b") is None + assert graph.get_backend_candidates_for_node("c") is None + assert graph.get_backend_candidates_for_node("d") is None + + assert graph.get_backend_candidates_for_node_set(["a", "b"]) is None + assert graph.get_backend_candidates_for_node_set(["a", "b", "c"]) is None + + def test_get_backend_candidates_intersection(self): + graph = _GraphViewer.from_edges( + # a b c + # \ / \ / + # d e + # \ / + # f + [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f")], + supporting_backends_mapper=supporting_backends_from_node_id_dict( + {"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]} + ), + ) + assert graph.get_backend_candidates_for_node("a") == {"b1", "b2"} + assert graph.get_backend_candidates_for_node("b") == {"b2", "b3"} + assert graph.get_backend_candidates_for_node("c") == {"b4"} + assert graph.get_backend_candidates_for_node("d") == {"b2"} + assert graph.get_backend_candidates_for_node("e") == set() + assert graph.get_backend_candidates_for_node("f") == set() + + assert graph.get_backend_candidates_for_node_set(["a", "b"]) == {"b2"} + assert graph.get_backend_candidates_for_node_set(["a", "b", "d"]) == {"b2"} + assert graph.get_backend_candidates_for_node_set(["c", "d"]) == set() + + def test_find_forsaken_nodes(self): + graph = _GraphViewer.from_edges( + # a b c + # \ / \ / + # d e + # \ / + # f + # / \ + # g h + [("a", "d"), ("b", "d"), 
("b", "e"), ("c", "e"), ("d", "f"), ("e", "f"), ("f", "g"), ("f", "h")], + supporting_backends_mapper=supporting_backends_from_node_id_dict( + {"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]} + ), + ) + assert graph.find_forsaken_nodes() == {"e", "f", "g", "h"} + + def test_find_articulation_points_basic(self): + flat = { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True}, + } + graph = _GraphViewer.from_flat_graph(flat) + assert graph.find_articulation_points() == {"lc1", "ndvi1"} + + @pytest.mark.parametrize( + ["flat", "expected"], + [ + ( + { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}}, + }, + {"lc1", "ndvi1"}, + ), + ( + { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands1": { + "process_id": "filter_bands", + "arguments": {"data": {"from_node": "lc1"}, "bands": ["b1"]}, + }, + "bands2": { + "process_id": "filter_bands", + "arguments": {"data": {"from_node": "lc1"}, "bands": ["b2"]}, + }, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}}, + }, + "save1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "merge1"}, "format": "GTiff"}, + }, + }, + {"lc1", "merge1", "save1"}, + ), + ( + { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, + }, + "save1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "merge1"}, "format": "GTiff"}, + }, + }, + {"lc1", "lc2", "merge1", "save1"}, + ), + ( + { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands1": { + "process_id": "filter_bands", + "arguments": {"data": {"from_node": "lc1"}, "bands": ["b1"]}, + }, + "bbox1": { + "process_id": "filter_spatial", + "arguments": {"data": {"from_node": "bands1"}}, + }, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "bbox1"}}, + }, + "save1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "merge1"}}, + }, + }, + {"lc1", "merge1", "save1"}, + ), + ], + ) + def test_find_articulation_points(self, flat, expected): + graph = _GraphViewer.from_flat_graph(flat) + assert graph.find_articulation_points() == expected + + def test_split_at_minimal(self): + graph = _GraphViewer.from_edges( + [("a", "b")], supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"}) + ) + # Split at a + up, down = graph.split_at("a") + assert sorted(up.iter_nodes()) == [ + ("a", _GVNode(backend_candidates=["A"])), + ] + assert sorted(down.iter_nodes()) == [ + ("a", _GVNode(flows_to=["b"])), + ("b", _GVNode(depends_on=["a"])), + ] + # Split at b + up, down = graph.split_at("b") + assert sorted(up.iter_nodes()) == [ + ("a", _GVNode(flows_to=["b"], backend_candidates=["A"])), + ("b", _GVNode(depends_on=["a"])), + ] + assert sorted(down.iter_nodes()) == [ + ("b", _GVNode()), + ] + + def test_split_at_basic(self): + graph = _GraphViewer.from_edges( + [("a", "b"), ("b", "c")], + supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"}), + ) + up, down = graph.split_at("b") + assert sorted(up.iter_nodes()) == [ + 
("a", _GVNode(flows_to=["b"], backend_candidates=["A"])), + ("b", _GVNode(depends_on=["a"])), + ] + assert sorted(down.iter_nodes()) == [ + ("b", _GVNode(flows_to=["c"])), + ("c", _GVNode(depends_on=["b"])), + ] + + def test_split_at_complex(self): + graph = _GraphViewer.from_edges( + # a + # / \ + # b c X + # \ / \ | + # d e f Y + # \ / + # g + [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e"), ("e", "g"), ("f", "g"), ("X", "Y")] + ) + up, down = graph.split_at("e") + assert sorted(up.iter_nodes()) == sorted( + _GraphViewer.from_edges([("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e")]).iter_nodes() + ) + assert sorted(down.iter_nodes()) == sorted( + _GraphViewer.from_edges([("e", "g"), ("f", "g"), ("X", "Y")]).iter_nodes() + ) + + def test_split_at_non_articulation_point(self): + graph = _GraphViewer.from_edges( + # a + # /| + # b | + # \| + # c + [("a", "b"), ("b", "c"), ("a", "c")] + ) + + with pytest.raises(GraphSplitException, match="not an articulation point"): + _ = graph.split_at("b") + + # These should still work + up, down = graph.split_at("a") + assert sorted(up.iter_nodes()) == [ + ("a", _GVNode()), + ] + assert sorted(down.iter_nodes()) == [ + ("a", _GVNode(flows_to=["b", "c"])), + ("b", _GVNode(depends_on=["a"], flows_to=["c"])), + ("c", _GVNode(depends_on=["a", "b"])), + ] + + up, down = graph.split_at("c") + assert sorted(up.iter_nodes()) == [ + ("a", _GVNode(flows_to=["b", "c"])), + ("b", _GVNode(depends_on=["a"], flows_to=["c"])), + ("c", _GVNode(depends_on=["a", "b"])), + ] + assert sorted(down.iter_nodes()) == [ + ("c", _GVNode()), + ] + + def test_split_at_multiple_empty(self): + graph = _GraphViewer.from_edges([("a", "b")]) + result = graph.split_at_multiple([]) + assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == { + None: [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))], + } + + def test_split_at_multiple_single(self): + graph = _GraphViewer.from_edges([("a", "b"), ("b", "c")]) + result = graph.split_at_multiple(["b"]) + assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == { + "b": [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))], + None: [("b", _GVNode(flows_to="c")), ("c", _GVNode(depends_on="b"))], + } + + def test_split_at_multiple_basic(self): + graph = _GraphViewer.from_edges( + [("a", "b"), ("b", "c"), ("c", "d")], + supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"}), + ) + result = graph.split_at_multiple(["b", "c"]) + assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == { + "b": [("a", _GVNode(flows_to="b", backend_candidates="A")), ("b", _GVNode(depends_on="a"))], + "c": [("b", _GVNode(flows_to="c")), ("c", _GVNode(depends_on="b"))], + None: [("c", _GVNode(flows_to="d")), ("d", _GVNode(depends_on="c"))], + } + + def test_split_at_multiple_invalid(self): + """Split nodes should be in downstream order""" + graph = _GraphViewer.from_edges( + [("a", "b"), ("b", "c"), ("c", "d")], + ) + # Downstream order: works + _ = graph.split_at_multiple(["b", "c"]) + # Upstream order: fails + with pytest.raises(GraphSplitException, match="Invalid node id 'b'"): + _ = graph.split_at_multiple(["c", "b"]) + + def test_produce_split_locations_simple(self): + """Simple produce_split_locations use case: no need for splits""" + flat = { + # lc1 + # | + # ndvi1 + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True}, + } + graph = 
_GraphViewer.from_flat_graph(
+            flat, supporting_backends=supporting_backends_from_node_id_dict({"lc1": "b1"})
+        )
+        assert list(graph.produce_split_locations()) == [[]]
+
+    def test_produce_split_locations_merge_basic(self):
+        """
+        Basic produce_split_locations use case:
+        two load collections on different backends and a merge
+        """
+        flat = {
+            # lc1  lc2
+            #   \  /
+            #  merge1
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "merge1": {
+                "process_id": "merge_cubes",
+                "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
+            },
+        }
+        graph = _GraphViewer.from_flat_graph(
+            flat,
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+        )
+        assert sorted(graph.produce_split_locations()) == [["lc1"], ["lc2"]]
+
+    def test_produce_split_locations_merge_longer(self):
+        flat = {
+            # lc1     lc2
+            #  |       |
+            # bands1  bands2
+            #     \   /
+            #     merge1
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+            "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
+            "merge1": {
+                "process_id": "merge_cubes",
+                "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}},
+            },
+        }
+        graph = _GraphViewer.from_flat_graph(
+            flat,
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+        )
+        assert sorted(graph.produce_split_locations(limit=2)) == [["bands1"], ["bands2"]]
+        assert list(graph.produce_split_locations(limit=4)) == [["bands1"], ["bands2"], ["lc1"], ["lc2"]]
+
+    def test_produce_split_locations_merge_longer_triangle(self):
+        flat = {
+            #     lc1
+            #    /  |
+            # bands1 |     lc2
+            #     \  |      |
+            #     mask1   bands2
+            #        \     /
+            #        merge1
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+            "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+            "mask1": {
+                "process_id": "mask",
+                "arguments": {"data": {"from_node": "bands1"}, "mask": {"from_node": "lc1"}},
+            },
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
+            "merge1": {
+                "process_id": "merge_cubes",
+                "arguments": {"cube1": {"from_node": "mask1"}, "cube2": {"from_node": "bands2"}},
+            },
+        }
+        graph = _GraphViewer.from_flat_graph(
+            flat,
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+        )
+        assert list(graph.produce_split_locations(limit=4)) == [["bands2"], ["mask1"], ["lc2"], ["lc1"]]
+
+    def test_produce_split_locations_allow_split(self):
+        """Usage of a custom `allow_split` predicate to restrict split locations"""
+        flat = {
+            # lc1     lc2
+            #  |       |
+            # bands1  bands2
+            #     \   /
+            #     merge1
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+            "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
+            "merge1": {
+                "process_id": "merge_cubes",
+                "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}},
+            },
+        }
+        graph = _GraphViewer.from_flat_graph(
+            flat,
supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}), + ) + assert list(graph.produce_split_locations()) == [["bands1"], ["bands2"], ["lc1"], ["lc2"]] + assert list(graph.produce_split_locations(allow_split=lambda n: n not in {"bands1", "lc2"})) == [ + ["bands2"], + ["lc1"], + ] + + +class TestDeepGraphSplitter: + def test_no_split(self): + splitter = DeepGraphSplitter(supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"]})) + flat = { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True}, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids={"lc1", "ndvi1"}, + primary_backend_id="b1", + secondary_graphs=[], + ) + + def test_simple_split(self): + """ + Most simple split use case: two load_collections from different backends, merged. + """ + splitter = DeepGraphSplitter( + supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}) + ) + flat = { + # lc1 lc2 + # \ / + # merge + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "merge": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids={"lc1", "lc2", "merge"}, + primary_backend_id="b2", + secondary_graphs=[ + _PGSplitSubGraph( + split_node="lc1", + node_ids={"lc1"}, + backend_id="b1", + ) + ], + ) + + def test_simple_deep_split(self): + """ + Simple deep split use case: + two load_collections from different backends, with some additional filtering, merged. 
+ """ + splitter = DeepGraphSplitter( + supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}) + ) + flat = { + # lc1 lc2 + # | | + # bands1 temporal2 + # \ / + # merge + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, + "temporal2": { + "process_id": "filter_temporal", + "arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]}, + }, + "merge": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids={"lc2", "temporal2", "bands1", "merge"}, + primary_backend_id="b2", + secondary_graphs=[_PGSplitSubGraph(split_node="bands1", node_ids={"lc1", "bands1"}, backend_id="b1")], + ) + + def test_shallow_triple_split(self): + splitter = DeepGraphSplitter( + supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]}) + ) + flat = { + # lc1 lc2 lc3 + # \ / / + # merge1 / + # \ / + # merge2 + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc3": {"process_id": "load_collection", "arguments": {"id": "S3"}}, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, + }, + "merge2": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "merge1"}, "cube2": {"from_node": "lc3"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids={"lc1", "lc2", "lc3", "merge1", "merge2"}, + primary_backend_id="b2", + secondary_graphs=[ + _PGSplitSubGraph(split_node="lc1", node_ids={"lc1"}, backend_id="b1"), + _PGSplitSubGraph(split_node="lc3", node_ids={"lc3"}, backend_id="b3"), + ], + ) + + def test_triple_split(self): + splitter = DeepGraphSplitter( + supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]}) + ) + flat = { + # lc1 lc2 lc3 + # | | | + # bands1 temporal2 spatial3 + # \ / / + # merge1 / + # \ / + # merge2 + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc3": {"process_id": "load_collection", "arguments": {"id": "S3"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, + "temporal2": { + "process_id": "filter_temporal", + "arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]}, + }, + "spatial3": {"process_id": "filter_spatial", "arguments": {"data": {"from_node": "lc3"}, "extent": "EU"}}, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}}, + }, + "merge2": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "merge1"}, "cube2": {"from_node": "spatial3"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids={"merge2", "merge1", "lc3", "spatial3"}, + primary_backend_id="b3", + secondary_graphs=[ + _PGSplitSubGraph(split_node="bands1", node_ids={"bands1", "lc1"}, backend_id="b1"), + _PGSplitSubGraph( + split_node="merge1", node_ids={"bands1", 
"merge1", "temporal2", "lc2"}, backend_id="b2" + ), + ], + ) + + @pytest.mark.parametrize( + ["primary_backend", "secondary_graph"], + [ + ("b1", _PGSplitSubGraph(split_node="lc2", node_ids={"lc2"}, backend_id="b2")), + ("b2", _PGSplitSubGraph(split_node="lc1", node_ids={"lc1"}, backend_id="b1")), + ], + ) + def test_split_with_primary_backend(self, primary_backend, secondary_graph): + """Test `primary_backend` argument of DeepGraphSplitter""" + splitter = DeepGraphSplitter( + supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}), + primary_backend=primary_backend, + ) + flat = { + # lc1 lc2 + # \ / + # merge + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "merge": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids={"lc1", "lc2", "merge"}, + primary_backend_id=primary_backend, + secondary_graphs=[secondary_graph], + ) + + @pytest.mark.parametrize( + ["split_deny_list", "split_node", "primary_node_ids", "secondary_node_ids"], + [ + ({}, "temporal2", {"lc1", "bands1", "temporal2", "merge"}, {"lc2", "temporal2"}), + ({"filter_bands", "filter_temporal"}, "lc2", {"lc1", "lc2", "bands1", "temporal2", "merge"}, {"lc2"}), + ], + ) + def test_split_deny_list(self, split_deny_list, split_node, primary_node_ids, secondary_node_ids): + """ + Simple deep split use case: + two load_collections from different backends, with some additional filtering, merged. + """ + splitter = DeepGraphSplitter( + supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}), + primary_backend="b1", + split_deny_list=split_deny_list, + ) + flat = { + # lc1 lc2 + # | | + # bands1 temporal2 + # \ / + # merge + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, + "temporal2": { + "process_id": "filter_temporal", + "arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]}, + }, + "merge": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids=primary_node_ids, + primary_backend_id="b1", + secondary_graphs=[_PGSplitSubGraph(split_node=split_node, node_ids=secondary_node_ids, backend_id="b2")], + )