[Data] Adding in per node metrics (ray-project#49705)
## Why are these changes needed?
Per-node metrics make debugging more efficient because they give a more
granular view of what is happening on each node.

Ray Data currently collects a number of metrics that we would want to
group by node, but the node is not reported with them, and the stats
collector actor aggregates them across nodes. This PR adds per-node
metrics to `OpRuntimeMetrics`. These are not aggregated across nodes, so
the data dashboard can show them segmented by node.
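
As a rough sketch of the idea (not the PR's actual code), counters can be kept per node ID rather than summed into one value per operator. The `NodeMetrics` fields mirror the dataclass added in this PR; `record_output` and the node IDs are hypothetical names used only for illustration.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict


@dataclass
class NodeMetrics:
    num_tasks_finished: int = 0
    bytes_outputs_of_finished_tasks: int = 0
    blocks_outputs_of_finished_tasks: int = 0


# One NodeMetrics per node ID, created lazily on first access.
per_node: Dict[str, NodeMetrics] = defaultdict(NodeMetrics)


def record_output(node_id: str, size_bytes: int) -> None:
    """Hypothetical helper: attribute one output block to a node."""
    metrics = per_node[node_id]
    metrics.bytes_outputs_of_finished_tasks += size_bytes
    metrics.blocks_outputs_of_finished_tasks += 1


record_output("node-a", 100)
record_output("node-a", 50)
record_output("node-b", 10)
print(per_node["node-a"])
```

Because the dictionary is keyed by node ID, each node keeps its own totals and nothing is merged until a consumer chooses to do so.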

## Example
### Script
```python
import ray
import time

def f(x):
    time.sleep(0.1)
    return x

file_path = "s3://air-example-data-2/100G-xgboost-data.parquet/"
ds = ray.data.read_parquet(file_path).limit(10_000_000)
ds = ds.map_batches(f)

for _ in ds.iter_batches():
    pass
```

### Output
```
(base) ray@ip-10-0-61-222:~/default$ python gen_metrics.py 
2025-01-16 16:33:49,968 INFO worker.py:1654 -- Connecting to existing Ray cluster at address: 10.0.61.222:6379...
2025-01-16 16:33:49,977 INFO worker.py:1832 -- Connected to Ray cluster. View the dashboard at https://session-64eiepsal97ynjwq1gb53c43vb.i.anyscaleuserdata-staging.com 
2025-01-16 16:33:49,979 INFO packaging.py:366 -- Pushing file package 'gcs://_ray_pkg_e7896c7ed49efce702fc2ded295073e96fe54a3a.zip' (0.00MiB) to Ray cluster...
2025-01-16 16:33:49,979 INFO packaging.py:379 -- Successfully pushed file package 'gcs://_ray_pkg_e7896c7ed49efce702fc2ded295073e96fe54a3a.zip'.
2025-01-16 16:33:51,418 INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-01-16_16-33-18_905648_2451/logs/ray-data
2025-01-16 16:33:51,418 INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[PartitionFiles] -> TaskPoolMapOperator[ReadFiles] -> LimitOperator[limit=10000000] -> TaskPoolMapOperator[MapBatches(f)]
(autoscaler +1m19s) Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.                                                                                                                                                                                                                                                     
✔️  Dataset execution finished in 113.05 seconds: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10.0M/10.0M [01:53<00:00, 88.5k row/s] 
- ListFiles: Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 0.0B object store: : 1.00k row [01:53, 8.85 row/s]                                                                                                                                                                                                                                                                       
- PartitionFiles: Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 87.3KB object store: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.00k/1.00k [01:53<00:00, 8.85 row/s]
- ReadFiles: Tasks: 0; Queued blocks: 311; Resources: 0.0 CPU, 2.4GB object store: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17.9M/17.9M [01:53<00:00, 159k row/s]
- limit=10000000: Tasks: 0; Queued blocks: 19; Resources: 0.0 CPU, 0.0B object store: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10.0M/10.0M [01:53<00:00, 88.5k row/s]
- MapBatches(f): Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 1.2GB object store: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10.0M/10.0M [01:53<00:00, 88.5k row/s
```

### New Charts
<img width="1462" alt="image"
src="https://github.com/user-attachments/assets/218183df-243f-4c84-9af9-cc362fac0b7e"
/>
<img width="1465" alt="image"
src="https://github.com/user-attachments/assets/4bdfef16-c773-45c2-bc56-512df70ba0c4"
/>
<img width="1468" alt="image"
src="https://github.com/user-attachments/assets/da19ac5f-33f8-46fe-9677-afbe08a55af1"
/>
<img width="1471" alt="image"
src="https://github.com/user-attachments/assets/72a85d24-ea4e-4c7b-9d96-244325d6333d"
/>

---------

Signed-off-by: Matthew Owen <mowen@anyscale.com>
omatthew98 authored Feb 28, 2025
1 parent 79aa176 commit 53678c1
Showing 8 changed files with 270 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
fill=0,
stack=False,
),
# TODO(mowen): Determine if we actually need bytes allocated since its not being used.
Panel(
id=2,
title="Bytes Allocated",
Expand Down Expand Up @@ -285,6 +286,38 @@
fill=0,
stack=False,
),
Panel(
id=43,
title="Output Bytes from Finished Tasks / Second (by Node)",
description=(
"Byte size of output blocks from finished tasks per second, grouped by node."
),
unit="Bps",
targets=[
Target(
expr="sum(rate(ray_data_bytes_outputs_of_finished_tasks_per_node{{{global_filters}}}[1m])) by (dataset, node_ip)",
legend="Bytes output / Second: {{dataset}}, {{node_ip}}",
)
],
fill=0,
stack=False,
),
Panel(
id=48,
title="Blocks from Finished Tasks / Second (by Node)",
description=(
"Number of output blocks from finished tasks per second, grouped by node."
),
unit="blocks/s",
targets=[
Target(
expr="sum(rate(ray_data_blocks_outputs_of_finished_tasks_per_node{{{global_filters}}}[1m])) by (dataset, node_ip)",
legend="Blocks output / Second: {{dataset}}, {{node_ip}}",
)
],
fill=0,
stack=False,
),
# Ray Data Metrics (Tasks)
Panel(
id=29,
Expand Down Expand Up @@ -342,6 +375,20 @@
fill=0,
stack=False,
),
Panel(
id=46,
title="Task Throughput (by Node)",
description="Number of finished tasks per second, grouped by node.",
unit="tasks/s",
targets=[
Target(
expr="sum(rate(ray_data_num_tasks_finished_per_node{{{global_filters}}}[1m])) by (dataset, node_ip)",
legend="Finished Tasks: {{dataset}}, {{node_ip}}",
)
],
fill=0,
stack=False,
),
Panel(
id=33,
title="Failed Tasks",
Expand Down
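
The triple braces in the `expr` strings make sense if each string is treated as a Python format template, which is an assumption here about how the dashboard code renders it: `{{` and `}}` are escaped literal braces and `{global_filters}` is the placeholder. A small sketch with a made-up filter value:

```python
# Template copied from the "Task Throughput (by Node)" panel above.
expr_template = (
    "sum(rate(ray_data_num_tasks_finished_per_node"
    "{{{global_filters}}}[1m])) by (dataset, node_ip)"
)

# Hypothetical filter string; the real value is supplied by the dashboard.
rendered = expr_template.format(global_filters='SessionName="demo"')
print(rendered)
```

After formatting, the filter ends up inside a single pair of braces, which is the label-selector syntax PromQL expects.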
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from collections import defaultdict
from dataclasses import Field, dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional
Expand All @@ -8,6 +9,7 @@
from ray.data._internal.execution.bundle_queue import create_bundle_queue
from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
from ray.data._internal.memory_tracing import trace_allocation
from ray.data.block import BlockMetadata

if TYPE_CHECKING:
from ray.data._internal.execution.interfaces.physical_operator import (
Expand All @@ -24,6 +26,8 @@

_METRICS: List["MetricDefinition"] = []

NODE_UNKNOWN = "unknown"


class MetricsGroup(Enum):
INPUTS = "inputs"
Expand Down Expand Up @@ -104,6 +108,13 @@ class RunningTaskInfo:
start_time: float


@dataclass
class NodeMetrics:
    num_tasks_finished: int = 0
    bytes_outputs_of_finished_tasks: int = 0
    blocks_outputs_of_finished_tasks: int = 0


class OpRuntimesMetricsMeta(type):
def __init__(cls, name, bases, dict):
# NOTE: `Field.name` isn't set until the dataclass is created, so we can't
Expand All @@ -125,6 +136,14 @@ def __init__(cls, name, bases, dict):
_METRICS.append(metric)


def node_id_from_block_metadata(meta: BlockMetadata) -> str:
    if meta.exec_stats is not None and meta.exec_stats.node_id is not None:
        node_id = meta.exec_stats.node_id
    else:
        node_id = NODE_UNKNOWN
    return node_id


class TaskDurationStats:
"""
Tracks the running mean and variance incrementally with Welford's algorithm
Expand Down Expand Up @@ -352,6 +371,9 @@ def __init__(self, op: "PhysicalOperator"):
self._pending_task_inputs = create_bundle_queue()
self._op_task_duration_stats = TaskDurationStats()

self._per_node_metrics: Dict[str, NodeMetrics] = defaultdict(NodeMetrics)
self._per_node_metrics_enabled: bool = op.data_context.enable_per_node_metrics

@property
def extra_metrics(self) -> Dict[str, Any]:
"""Return a dict of extra metrics."""
Expand Down Expand Up @@ -583,6 +605,15 @@ def on_task_output_generated(self, task_index: int, output: RefBundle):
self.rows_task_outputs_generated += meta.num_rows
trace_allocation(block_ref, "operator_output")

        # Update per node metrics
        if self._per_node_metrics_enabled:
            for _, meta in output.blocks:
                node_id = node_id_from_block_metadata(meta)
                node_metrics = self._per_node_metrics[node_id]

                node_metrics.bytes_outputs_of_finished_tasks += meta.size_bytes
                node_metrics.blocks_outputs_of_finished_tasks += 1

def on_task_finished(self, task_index: int, exception: Optional[Exception]):
"""Callback when a task is finished."""
self.num_tasks_running -= 1
Expand Down Expand Up @@ -619,5 +650,19 @@ def on_task_finished(self, task_index: int, exception: Optional[Exception]):

self.obj_store_mem_freed += total_input_size

        # Update per node metrics
        if self._per_node_metrics_enabled:
            node_ids = set()
            for _, meta in inputs.blocks:
                node_id = node_id_from_block_metadata(meta)
                node_metrics = self._per_node_metrics[node_id]

                # Stats to update once per node id or if node id is unknown
                if node_id not in node_ids or node_id == NODE_UNKNOWN:
                    node_metrics.num_tasks_finished += 1

                # Keep track of node ids to ensure we don't double count
                node_ids.add(node_id)

inputs.destroy_if_owned()
del self._running_tasks[task_index]
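
The counting rule in `on_task_finished` can be sketched in isolation. Per the hunk above, node IDs come from the metadata of the task's input blocks, a known node is counted once per task, and blocks with no node ID are counted every time. The function name and the sample IDs below are hypothetical; this is a standalone illustration, not the operator code.

```python
from collections import defaultdict
from typing import Dict, Iterable, Optional

NODE_UNKNOWN = "unknown"


def count_finished_task(
    counts: Dict[str, int], input_block_node_ids: Iterable[Optional[str]]
) -> None:
    """Count one finished task per distinct node among its input blocks."""
    seen = set()
    for raw_id in input_block_node_ids:
        node_id = raw_id if raw_id is not None else NODE_UNKNOWN
        # Known nodes are counted once; unknown blocks are counted every time,
        # matching the condition in the diff above.
        if node_id not in seen or node_id == NODE_UNKNOWN:
            counts[node_id] += 1
        seen.add(node_id)


counts: Dict[str, int] = defaultdict(int)
count_finished_task(counts, ["node-a", "node-a", "node-b", None, None])
print(dict(counts))
```

With these inputs, `node-a` and `node-b` are each counted once, while the two blocks without a node ID both increment the `unknown` bucket.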
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def __init__(
self._started = False
self._in_task_submission_backpressure = False
self._in_task_output_backpressure = False
self._metrics = OpRuntimeMetrics(self)
self._estimated_num_output_bundles = None
self._estimated_output_num_rows = None
self._execution_completed = False
Expand All @@ -210,6 +209,8 @@ def __init__(
self._logical_operators: List[LogicalOperator] = []
self._data_context = data_context
self._id = str(uuid.uuid4())
# Initialize metrics after data_context is set
self._metrics = OpRuntimeMetrics(self)

def __reduce__(self):
raise ValueError("Operator is not serializable.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class RefBundle:
"""

# The size_bytes must be known in the metadata, num_rows is optional.
blocks: Tuple[Tuple[ObjectRef[Block], BlockMetadata]]
blocks: Tuple[Tuple[ObjectRef[Block], BlockMetadata], ...]

# Whether we own the blocks (can safely destroy them).
owns_blocks: bool
Expand Down
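
The annotation change above matters because `Tuple[X]` describes a tuple of exactly one element, while `Tuple[X, ...]` describes a tuple of any length. A short illustration using `int` as a stand-in element type (the alias names are made up for this example):

```python
from typing import Tuple, get_args

OneItem = Tuple[int]         # exactly one element
AnyLength = Tuple[int, ...]  # zero or more elements, all int

# Inspect the type arguments to see the difference.
print(get_args(OneItem))
print(get_args(AnyLength))
```

Since a `RefBundle` can hold many blocks, the variadic form is the one that matches how `blocks` is iterated elsewhere in this diff.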
5 changes: 4 additions & 1 deletion python/ray/data/_internal/execution/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ray.data.context import DataContext

if TYPE_CHECKING:
from ray.data._internal.execution.streaming_executor_state import OpState
from ray.data._internal.execution.streaming_executor_state import Topology


Expand Down Expand Up @@ -101,7 +102,9 @@ def __init__(
)
)

def _estimate_object_store_memory(self, op, state) -> int:
def _estimate_object_store_memory(
self, op: "PhysicalOperator", state: "OpState"
) -> int:
# Don't count input refs towards dynamic memory usage, as they have been
# pre-created already outside this execution.
if isinstance(op, InputDataBuffer):
Expand Down
86 changes: 82 additions & 4 deletions python/ray/data/_internal/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import logging
import threading
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
from uuid import uuid4

import numpy as np
Expand All @@ -14,7 +15,9 @@
from ray.data._internal.block_list import BlockList
from ray.data._internal.execution.interfaces.op_runtime_metrics import (
MetricsGroup,
NodeMetrics,
OpRuntimeMetrics,
NODE_UNKNOWN,
)
from ray.data._internal.util import capfirst
from ray.data.block import BlockMetadata, BlockStats
Expand Down Expand Up @@ -158,6 +161,9 @@ def __init__(self, max_stats=1000):
# Dataset metadata to be queried directly by DashboardHead api.
self.datasets: Dict[str, Any] = {}

# Cache of calls to ray.nodes() to prevent unnecessary network calls
self._ray_nodes_cache: Dict[str, str] = {}

# Ray Data dashboard metrics
# Everything is a gauge because we need to reset all of
# a dataset's metrics to 0 after each finishes execution.
Expand Down Expand Up @@ -249,6 +255,9 @@ def __init__(self, max_stats=1000):
)
)

# Per Node metrics
self.per_node_metrics = self._create_prometheus_metrics_for_per_node_metrics()

iter_tag_keys = ("dataset",)
self.iter_total_blocked_s = Gauge(
"data_iter_total_blocked_seconds",
Expand Down Expand Up @@ -282,6 +291,17 @@ def _create_prometheus_metrics_for_execution_metrics(
)
return metrics

    def _create_prometheus_metrics_for_per_node_metrics(self) -> Dict[str, Gauge]:
        metrics = {}
        for field in fields(NodeMetrics):
            metric_name = f"data_{field.name}_per_node"
            metrics[field.name] = Gauge(
                metric_name,
                description="",
                tag_keys=("dataset", "node_ip"),
            )
        return metrics
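
The naming scheme used here can be checked without Ray by deriving the names from the dataclass fields alone. This sketch redefines `NodeMetrics` locally and leaves out the `Gauge` objects. The dashboard queries earlier in this diff use the same names with a `ray_` prefix, which is presumably added when the metrics are exported.

```python
from dataclasses import dataclass, fields


@dataclass
class NodeMetrics:
    num_tasks_finished: int = 0
    bytes_outputs_of_finished_tasks: int = 0
    blocks_outputs_of_finished_tasks: int = 0


# Same f-string as in the diff, applied to every dataclass field.
metric_names = {f.name: f"data_{f.name}_per_node" for f in fields(NodeMetrics)}

for field_name, metric_name in metric_names.items():
    print(field_name, "->", metric_name)
```

Driving the metric list from `fields(NodeMetrics)` means a new counter added to the dataclass gets a gauge without further changes in the stats actor.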

def record_start(self, stats_uuid):
self.start_time[stats_uuid] = time.perf_counter()
self.fifo_queue.append(stats_uuid)
Expand Down Expand Up @@ -334,6 +354,7 @@ def update_execution_metrics(
op_metrics: List[Dict[str, Union[int, float]]],
operator_tags: List[str],
state: Dict[str, Any],
per_node_metrics: Optional[Dict[str, Dict[str, Union[int, float]]]] = None,
):
for stats, operator_tag in zip(op_metrics, operator_tags):
tags = self._create_tags(dataset_tag, operator_tag)
Expand Down Expand Up @@ -364,10 +385,37 @@ def update_execution_metrics(
for field_name, prom_metric in self.execution_metrics_misc.items():
prom_metric.set(stats.get(field_name, 0), tags)

        # Update per node metrics if they exist, the creation of these metrics is controlled
        # by the _data_context.enable_per_node_metrics flag in the streaming executor but
        # that is not exposed in the _StatsActor so here we simply check if the metrics exist
        # and if so, update them
        if per_node_metrics is not None:
            for node_id, node_metrics in per_node_metrics.items():
                # Translate node_id into node_name (the node ip), cache node info
                if node_id not in self._ray_nodes_cache:
                    # Rebuilding this cache will fetch all nodes, this
                    # only needs to be done up to once per loop
                    self._rebuild_ray_nodes_cache()

                node_ip = self._ray_nodes_cache.get(node_id, NODE_UNKNOWN)

                tags = self._create_tags(dataset_tag=dataset_tag, node_ip_tag=node_ip)
                for metric_name, metric_value in node_metrics.items():
                    prom_metric = self.per_node_metrics[metric_name]
                    prom_metric.set(metric_value, tags)

# This update is called from a dataset's executor,
# so all tags should contain the same dataset
self.update_dataset(dataset_tag, state)

    def _rebuild_ray_nodes_cache(self) -> None:
        current_nodes = ray.nodes()
        for node in current_nodes:
            node_id = node.get("NodeID", None)
            node_name = node.get("NodeName", None)
            if node_id is not None and node_name is not None:
                self._ray_nodes_cache[node_id] = node_name
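
The lookup-then-rebuild pattern can be sketched without a cluster. In the code below, `list_nodes` is a stand-in for `ray.nodes()` and the `NodeIpCache` class and `fetches` counter are hypothetical additions for illustration; only the `NodeID` and `NodeName` keys come from the diff.

```python
from typing import Callable, Dict, List

NODE_UNKNOWN = "unknown"


class NodeIpCache:
    """Maps node IDs to node names, refetching the node list on a miss."""

    def __init__(self, list_nodes: Callable[[], List[dict]]):
        # `list_nodes` stands in for `ray.nodes()` so the sketch runs without Ray.
        self._list_nodes = list_nodes
        self._cache: Dict[str, str] = {}
        self.fetches = 0  # Number of times the node list was fetched.

    def lookup(self, node_id: str) -> str:
        if node_id not in self._cache:
            self._rebuild()
        return self._cache.get(node_id, NODE_UNKNOWN)

    def _rebuild(self) -> None:
        self.fetches += 1
        for node in self._list_nodes():
            node_id = node.get("NodeID")
            node_name = node.get("NodeName")
            if node_id is not None and node_name is not None:
                self._cache[node_id] = node_name


cache = NodeIpCache(lambda: [{"NodeID": "abc", "NodeName": "10.0.0.1"}])
print(cache.lookup("abc"))      # Miss, triggers a fetch.
print(cache.lookup("abc"))      # Hit, no fetch.
print(cache.lookup("missing"))  # Miss, triggers another fetch.
```

In this sketch, an ID that never appears in the node list causes a refetch on every lookup, since a failed lookup leaves nothing in the cache.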

def update_iteration_metrics(
self,
stats: "DatasetStats",
Expand Down Expand Up @@ -404,10 +452,17 @@ def get_datasets(self, job_id: Optional[str] = None):
return self.datasets
return {k: v for k, v in self.datasets.items() if v["job_id"] == job_id}

def _create_tags(self, dataset_tag: str, operator_tag: Optional[str] = None):
def _create_tags(
self,
dataset_tag: str,
operator_tag: Optional[str] = None,
node_ip_tag: Optional[str] = None,
):
tags = {"dataset": dataset_tag}
if operator_tag is not None:
tags["operator"] = operator_tag
if node_ip_tag is not None:
tags["node_ip"] = node_ip_tag
return tags


Expand Down Expand Up @@ -549,6 +604,28 @@ def _run_update_loop():

# Execution methods

    def _aggregate_per_node_metrics(
        self, op_metrics: List[OpRuntimeMetrics]
    ) -> Optional[Mapping[str, Mapping[str, Union[int, float]]]]:
        """
        Aggregate per-node metrics from a list of OpRuntimeMetrics objects.
        If per-node metrics are disabled in the current DataContext, returns None.
        Otherwise, it sums up all NodeMetrics fields across the provided metrics and
        returns a nested dictionary mapping each node ID to a dict of field values.
        """
        if not DataContext.get_current().enable_per_node_metrics:
            return None

        aggregated_by_node = defaultdict(lambda: defaultdict(int))
        for metrics in op_metrics:
            for node_id, node_metrics in metrics._per_node_metrics.items():
                agg_node_metrics = aggregated_by_node[node_id]
                for f in fields(NodeMetrics):
                    agg_node_metrics[f.name] += getattr(node_metrics, f.name)

        return aggregated_by_node
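
The aggregation step sums each `NodeMetrics` field across operators while keeping nodes separate. A standalone sketch follows, with a locally defined `NodeMetrics`, a hypothetical `aggregate` function, and made-up sample values. Converting the result to plain dicts is a choice made here for easy printing, not something the diff does.

```python
from collections import defaultdict
from dataclasses import dataclass, fields
from typing import Dict, List


@dataclass
class NodeMetrics:
    num_tasks_finished: int = 0
    bytes_outputs_of_finished_tasks: int = 0
    blocks_outputs_of_finished_tasks: int = 0


def aggregate(per_op: List[Dict[str, NodeMetrics]]) -> Dict[str, Dict[str, int]]:
    """Sum every NodeMetrics field per node ID across all operators."""
    totals = defaultdict(lambda: defaultdict(int))
    for per_node in per_op:
        for node_id, node_metrics in per_node.items():
            for f in fields(NodeMetrics):
                totals[node_id][f.name] += getattr(node_metrics, f.name)
    return {node_id: dict(values) for node_id, values in totals.items()}


# Two operators, each with its own per-node counters.
read_op = {"node-a": NodeMetrics(2, 200, 4), "node-b": NodeMetrics(1, 50, 1)}
map_op = {"node-a": NodeMetrics(3, 300, 6)}

result = aggregate([read_op, map_op])
print(result)
```

The operator dimension is collapsed here, which fits the `("dataset", "node_ip")` tag keys used for the per-node gauges.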

def update_execution_metrics(
self,
dataset_tag: str,
Expand All @@ -558,7 +635,8 @@ def update_execution_metrics(
force_update: bool = False,
):
op_metrics_dicts = [metric.as_dict() for metric in op_metrics]
args = (dataset_tag, op_metrics_dicts, operator_tags, state)
per_node_metrics = self._aggregate_per_node_metrics(op_metrics)
args = (dataset_tag, op_metrics_dicts, operator_tags, state, per_node_metrics)
if force_update:
self._stats_actor().update_execution_metrics.remote(*args)
else:
Expand Down