Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jitter for manager heartbeats and updates #889

Merged
merged 8 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qcarchivetesting/qcarchivetesting/testing_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(
qcf_config["service_frequency"] = 5
qcf_config["loglevel"] = "DEBUG"
qcf_config["heartbeat_frequency"] = 3
qcf_config["heartbeat_frequency_jitter"] = 0.0
qcf_config["heartbeat_max_missed"] = 2

qcf_config["database"] = {"pool_size": 0}
Expand Down
9 changes: 6 additions & 3 deletions qcfractal/qcfractal/components/managers/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, root_socket: SQLAlchemySocket):
self._logger = logging.getLogger(__name__)

self._manager_heartbeat_frequency = root_socket.qcf_config.heartbeat_frequency
self._manager_heartbeat_frequency_jitter = root_socket.qcf_config.heartbeat_frequency_jitter
self._manager_max_missed_heartbeats = root_socket.qcf_config.heartbeat_max_missed

with self.root_socket.session_scope() as session:
Expand Down Expand Up @@ -292,11 +293,13 @@ def check_manager_heartbeats(self, session: Session) -> None:
----------
session
An existing SQLAlchemy session to use.
job_progress
An object used to report the current job progress and status
"""
self._logger.debug("Checking manager heartbeats")
manager_window = self._manager_max_missed_heartbeats * self._manager_heartbeat_frequency

# Take into account the maximum jitter allowed. The jitter is a *fraction* of the
# heartbeat frequency (managers wait up to frequency * (1 + jitter) between heartbeats),
# so it must scale the frequency rather than be added to it as a number of seconds.
manager_window = self._manager_max_missed_heartbeats * (
    self._manager_heartbeat_frequency * (1 + self._manager_heartbeat_frequency_jitter)
)
dt = now_at_utc() - timedelta(seconds=manager_window)

dead_managers = self.deactivate(modified_before=dt, reason="missing heartbeat", session=session)
Expand Down
8 changes: 7 additions & 1 deletion qcfractal/qcfractal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,17 @@ class FractalConfig(ConfigBase):
service_frequency: int = Field(60, description="The frequency at which to update services (in seconds)")
max_active_services: int = Field(20, description="The maximum number of concurrent active services")
heartbeat_frequency: int = Field(
1800, description="The frequency (in seconds) to check the heartbeat of compute managers"
1800,
description="The frequency (in seconds) to check the heartbeat of compute managers",
gt=0,
)
# Declared as float: the jitter is a fraction of heartbeat_frequency (default 0.1).
# An int annotation would cause fractional values supplied in a config to be
# truncated or rejected during validation.
heartbeat_frequency_jitter: float = Field(
    0.1, description="Jitter fraction to be applied to the heartbeat frequency", ge=0
)
heartbeat_max_missed: int = Field(
5,
description="The maximum number of heartbeats that a compute manager can miss. If more are missed, the worker is considered dead",
ge=0,
)

# Access logging
Expand Down
1 change: 1 addition & 0 deletions qcfractal/qcfractal/flask_app/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def get_public_server_information():
public_info = {
"name": qcf_cfg.name,
"manager_heartbeat_frequency": qcf_cfg.heartbeat_frequency,
"manager_heartbeat_frequency_jitter": qcf_cfg.heartbeat_frequency_jitter,
"manager_heartbeat_max_missed": qcf_cfg.heartbeat_max_missed,
"version": qcfractal_version,
"api_limits": qcf_cfg.api_limits.dict(),
Expand Down
2 changes: 2 additions & 0 deletions qcfractal/qcfractal/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def __init__(
qcf_cfg["hide_internal_errors"] = False
qcf_cfg["service_frequency"] = 10
qcf_cfg["heartbeat_frequency"] = 5
qcf_cfg["heartbeat_frequency_jitter"] = 0.0
qcf_cfg["heartbeat_max_missed"] = 3
qcf_cfg["api"] = {
"host": host,
Expand Down Expand Up @@ -262,6 +263,7 @@ def __init__(
parsl_run_dir=parsl_run_dir,
cluster="snowflake_compute",
update_frequency=5,
update_frequency_jitter=0.0,
server=FractalServerSettings(
fractal_uri=uri,
verify=False,
Expand Down
9 changes: 6 additions & 3 deletions qcfractalcompute/qcfractalcompute/compute_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from qcportal.managers import ManagerName
from qcportal.metadata_models import TaskReturnMetadata
from qcportal.record_models import RecordTask
from qcportal.utils import seconds_to_hms
from qcportal.utils import seconds_to_hms, apply_jitter
from . import __version__
from .apps.models import AppTaskResult
from .compress import compress_result
Expand Down Expand Up @@ -185,6 +185,7 @@ def __init__(self, config: FractalComputeConfig):
# Pull server info
self.server_info = self.client.get_server_information()
self.heartbeat_frequency = self.server_info["manager_heartbeat_frequency"]
self.heartbeat_frequency_jitter = self.server_info.get("manager_heartbeat_frequency_jitter", 0.0)

self.client.activate(__version__, self.all_program_info, tags=self.all_queue_tags)

Expand Down Expand Up @@ -288,13 +289,15 @@ def scheduler_update():
if not manual_updates:
self.update(new_tasks=True)
if not self._is_stopping:
self.scheduler.enter(self.manager_config.update_frequency, 1, scheduler_update)
delay = apply_jitter(self.manager_config.update_frequency, self.manager_config.update_frequency_jitter)
self.scheduler.enter(delay, 1, scheduler_update)

def scheduler_heartbeat():
if not manual_updates:
self.heartbeat()
if not self._is_stopping:
self.scheduler.enter(self.heartbeat_frequency, 1, scheduler_heartbeat)
delay = apply_jitter(self.heartbeat_frequency, self.heartbeat_frequency_jitter)
self.scheduler.enter(delay, 1, scheduler_heartbeat)

self.logger.info("Compute Manager successfully started.")

Expand Down
8 changes: 8 additions & 0 deletions qcfractalcompute/qcfractalcompute/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ class FractalComputeConfig(BaseModel):
"itself down to maintain integrity between it and the Fractal Server. Units of seconds",
gt=0,
)
update_frequency_jitter: float = Field(
0.1,
description="The update frequency will be modified by up to a certain amount for each request. The "
"update_frequency_jitter represents a fraction of the update_frequency to allow as a max. "
"Ie, update_frequency=60, and jitter=0.1, updates will happen between 54 and 66 seconds. "
"This helps with spreading out server load.",
ge=0,
)

max_idle_time: Optional[int] = Field(
None,
Expand Down
1 change: 1 addition & 0 deletions qcfractalcompute/qcfractalcompute/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
parsl_run_dir=parsl_run_dir,
cluster="mock_compute",
update_frequency=1,
update_frequency_jitter=0.0,
server=FractalServerSettings(
fractal_uri=uri,
verify=False,
Expand Down
15 changes: 12 additions & 3 deletions qcportal/qcportal/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def _send_request(self, req: requests.Request, allow_retries: bool = True) -> re

Parameters
----------
prep_req
req
A prepared request to send
allow_retries
If true, attempts to retry on certain kinds of errors
Expand Down Expand Up @@ -423,6 +423,7 @@ def _request(
url_params: Optional[Dict[str, Any]] = None,
internal_retry: Optional[bool] = True,
allow_retries: bool = True,
additional_headers: Optional[Dict[str, Any]] = None,
) -> requests.Response:
# If refresh token has expired, log in again
if self._jwt_refresh_exp and self._jwt_refresh_exp < time.time():
Expand All @@ -433,7 +434,9 @@ def _request(
self._refresh_JWT_token()

full_uri = self.address + endpoint
req = requests.Request(method=method.upper(), url=full_uri, data=body, params=url_params)
req = requests.Request(
method=method.upper(), url=full_uri, data=body, params=url_params, headers=additional_headers
)
r = self._send_request(req, allow_retries=allow_retries)

# If JWT token expired, automatically renew it and retry once. This should have been caught above,
Expand Down Expand Up @@ -468,6 +471,7 @@ def make_request(
body: Optional[Union[_T, Dict[str, Any]]] = None,
url_params: Optional[Union[_U, Dict[str, Any]]] = None,
allow_retries: bool = True,
additional_headers: Optional[Dict[str, Any]] = None,
) -> _V:
# If body_model or url_params_model are None, then use the type given
if body_model is None and body is not None:
Expand All @@ -489,7 +493,12 @@ def make_request(
parsed_url_params = parsed_url_params.dict()

r = self._request(
method, endpoint, body=serialized_body, url_params=parsed_url_params, allow_retries=allow_retries
method,
endpoint,
body=serialized_body,
url_params=parsed_url_params,
allow_retries=allow_retries,
additional_headers=additional_headers,
)
d = deserialize(r.content, r.headers["Content-Type"])

Expand Down
1 change: 1 addition & 0 deletions qcportal/qcportal/manager_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _update_on_server(self, manager_update: ManagerUpdateBody) -> None:
None,
body=manager_update,
allow_retries=False,
additional_headers={"Connection": "close"},
)

def activate(
Expand Down
6 changes: 6 additions & 0 deletions qcportal/qcportal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import logging
import math
import random
import re
import time
from contextlib import contextmanager, redirect_stderr, redirect_stdout
Expand Down Expand Up @@ -449,3 +450,8 @@ def update_nested_dict(d: Dict[str, Any], u: Dict[str, Any]):
else:
d[k] = v
return d


def apply_jitter(t: Union[int, float], jitter_fraction: float) -> float:
    """
    Randomly perturb a time value by up to a given fraction of itself

    A factor is drawn uniformly from ``[-jitter_fraction, jitter_fraction]`` and the
    result is ``t * (1 + factor)``, clamped so that it is never negative.

    Parameters
    ----------
    t
        The base value to perturb (for example, a delay in seconds)
    jitter_fraction
        Maximum fraction of ``t`` by which the result may differ from ``t``

    Returns
    -------
    :
        The perturbed value, or 0.0 if the perturbed value would be negative
    """
    scale = 1 + random.uniform(-jitter_fraction, jitter_fraction)
    jittered = t * scale

    # Never hand a negative delay back to the caller
    if jittered < 0.0:
        return 0.0
    return jittered