Skip to content

Commit

Permalink
Add error handling in various to_json() funcs. Break JSON write out from pickle write so that we maintain writing pickles if JSON fails
Browse files Browse the repository at this point in the history
  • Loading branch information
KaspariK committed Oct 24, 2024
1 parent c988990 commit 2c41653
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 98 deletions.
1 change: 0 additions & 1 deletion tests/serialize/runstate/dynamodb_state_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def large_object():
}


# TODO: Add better test for to_json?
@pytest.mark.usefixtures("store", "small_object", "large_object")
class TestDynamoDBStateStore:
def test_save(self, store, small_object, large_object):
Expand Down
22 changes: 15 additions & 7 deletions tron/actioncommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from io import StringIO
from shlex import quote
from typing import Optional

from tron.config import schema
from tron.serialize import filehandler
Expand Down Expand Up @@ -203,13 +204,20 @@ def __ne__(self, other):
return not self == other

@staticmethod
def to_json(state_data: dict) -> Optional[str]:
    """Serialize a SubprocessActionRunnerFactory state dict to a JSON string.

    Returns None (after logging) instead of raising, so callers can fall
    back to pickle-only persistence when JSON serialization fails.
    """
    try:
        payload = {field: state_data[field] for field in ("status_path", "exec_path")}
        return json.dumps(payload)
    except KeyError as e:
        log.error(f"Missing key in state_data: {e}")
        return None
    except Exception as e:
        log.error(f"Error serializing SubprocessActionRunnerFactory to JSON: {e}")
        return None


def create_action_runner_factory_from_config(config):
Expand Down
61 changes: 35 additions & 26 deletions tron/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,39 +53,48 @@ def copy(self):
return ActionCommandConfig(**self.state_data)

@staticmethod
def to_json(state_data: dict) -> Optional[str]:
    """Serialize the ActionCommandConfig instance to a JSON string.

    Returns None (after logging) on failure so callers can fall back to
    pickle-only persistence.
    """

    def serialize_namedtuple(obj):
        # namedtuples are tuples with a _fields attribute; render them as dicts
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return obj._asdict()
        return obj

    def serialize_namedtuple_list(items):
        return [serialize_namedtuple(item) for item in items]

    # Ordered field -> transform table; None means "copy the value as-is".
    # Insertion order is preserved so the emitted JSON matches the
    # original hand-written dict literal byte-for-byte.
    field_transforms = {
        "command": None,
        "cpus": None,
        "mem": None,
        "disk": None,
        "cap_add": None,
        "cap_drop": None,
        "constraints": list,
        "docker_image": None,
        "docker_parameters": list,
        "env": None,
        "secret_env": None,
        "secret_volumes": serialize_namedtuple_list,
        "projected_sa_volumes": serialize_namedtuple_list,
        "field_selector_env": None,
        "extra_volumes": list,
        "node_selectors": None,
        "node_affinities": serialize_namedtuple_list,
        "labels": None,
        "annotations": None,
        "service_account_name": None,
        "ports": None,
    }

    try:
        payload = {}
        for field, transform in field_transforms.items():
            value = state_data[field]
            payload[field] = value if transform is None else transform(value)
        return json.dumps(payload)
    except KeyError as e:
        log.error(f"Missing key in state_data: {e}")
        return None
    except Exception as e:
        log.error(f"Error serializing ActionCommandConfig to JSON: {e}")
        return None


@dataclass
Expand Down
82 changes: 48 additions & 34 deletions tron/core/actionrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,26 @@ def state_data(self):
return state_data

@staticmethod
def to_json(state_data: dict) -> Optional[str]:
    """Serialize the ActionRunAttempt instance to a JSON string.

    Returns None (after logging) on failure so callers can fall back to
    pickle-only persistence.
    """

    def isoformat_or_none(timestamp):
        # timestamps may be None for attempts that never started/finished
        return timestamp.isoformat() if timestamp else None

    try:
        payload = {
            "command_config": ActionCommandConfig.to_json(state_data["command_config"]),
            "start_time": isoformat_or_none(state_data["start_time"]),
            "end_time": isoformat_or_none(state_data["end_time"]),
            "rendered_command": state_data["rendered_command"],
            "exit_status": state_data["exit_status"],
            "mesos_task_id": state_data["mesos_task_id"],
            "kubernetes_task_id": state_data["kubernetes_task_id"],
        }
        return json.dumps(payload)
    except KeyError as e:
        log.error(f"Missing key in state_data: {e}")
        return None
    except Exception as e:
        log.error(f"Error serializing ActionRunAttempt to JSON: {e}")
        return None

@classmethod
def from_state(cls, state_data):
Expand Down Expand Up @@ -731,35 +738,42 @@ def state_data(self):
}

@staticmethod
def to_json(state_data: dict) -> Optional[str]:
    """Serialize the ActionRun instance to a JSON string.

    Returns None (after logging) on failure so callers can fall back to
    pickle-only persistence.
    """
    # Resolve the runner JSON before the try block (matching the original
    # flow); both factory to_json() paths are expected to handle their own
    # failures rather than raise here.
    runner_state = state_data.get("action_runner")
    if runner_state is None:
        action_runner_json = NoActionRunnerFactory.to_json()
    else:
        action_runner_json = SubprocessActionRunnerFactory.to_json(runner_state)

    def isoformat_or_none(timestamp):
        # start/end may be None for runs that never started/finished
        return timestamp.isoformat() if timestamp else None

    try:
        payload = {
            "job_run_id": state_data["job_run_id"],
            "action_name": state_data["action_name"],
            "state": state_data["state"],
            "original_command": state_data["original_command"],
            "start_time": isoformat_or_none(state_data["start_time"]),
            "end_time": isoformat_or_none(state_data["end_time"]),
            "node_name": state_data["node_name"],
            "exit_status": state_data["exit_status"],
            "attempts": [ActionRunAttempt.to_json(attempt) for attempt in state_data["attempts"]],
            "retries_remaining": state_data["retries_remaining"],
            "retries_delay": state_data["retries_delay"],
            "action_runner": action_runner_json,
            "executor": state_data["executor"],
            "trigger_downstreams": state_data["trigger_downstreams"],
            "triggered_by": state_data["triggered_by"],
            "on_upstream_rerun": state_data["on_upstream_rerun"],
            "trigger_timeout_timestamp": state_data["trigger_timeout_timestamp"],
        }
        return json.dumps(payload)
    except KeyError as e:
        log.error(f"Missing key in state_data: {e}")
        return None
    except Exception as e:
        log.error(f"Error serializing ActionRun to JSON: {e}")
        return None

def render_template(self, template):
"""Render our configured command using the command context."""
Expand Down
8 changes: 6 additions & 2 deletions tron/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,13 @@ def __init__(
log.info(f"{self} created")

@staticmethod
def to_json(state_data: dict) -> Optional[str]:
    """Serialize the Job instance to a JSON string.

    Returns None (after logging) on failure so callers can fall back to
    pickle-only persistence.
    """
    try:
        serialized = json.dumps(state_data)
    except Exception as e:
        log.error(f"Error serializing Job to JSON: {e}")
        return None
    return serialized

@classmethod
def from_config(
Expand Down
31 changes: 19 additions & 12 deletions tron/core/jobrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,26 @@ def __init__(
self.context = command_context.build_context(self, base_context)

@staticmethod
def to_json(state_data: dict) -> Optional[str]:
    """Serialize the JobRun instance to a JSON string.

    Returns None (after logging) on failure so callers can fall back to
    pickle-only persistence.
    """
    try:
        cleanup_state = state_data["cleanup_run"]
        payload = {
            "job_name": state_data["job_name"],
            "run_num": state_data["run_num"],
            "run_time": state_data["run_time"].isoformat() if state_data["run_time"] else None,
            "node_name": state_data["node_name"],
            "runs": [ActionRun.to_json(run) for run in state_data["runs"]],
            "cleanup_run": ActionRun.to_json(cleanup_state) if cleanup_state else None,
            "manual": state_data["manual"],
        }
        return json.dumps(payload)
    except KeyError as e:
        log.error(f"Missing key in state_data: {e}")
        return None
    except Exception as e:
        log.error(f"Error serializing JobRun to JSON: {e}")
        return None

@property
def id(self):
Expand Down
32 changes: 17 additions & 15 deletions tron/serialize/runstate/dynamodb_state_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar
Expand Down Expand Up @@ -161,11 +162,12 @@ def save(self, key_value_pairs) -> None:
self.save_queue[key] = (val, None)
else:
state_type = self.get_type_from_key(key)
json_val = self._serialize_item(state_type, val)
self.save_queue[key] = (
val,
json_val,
)
try:
json_val = self._serialize_item(state_type, val)
except Exception as e:
log.error(f"Failed to serialize JSON for key {key}: {e}")
json_val = None # Proceed without JSON if serialization fails
self.save_queue[key] = (val, json_val)
break

def _consume_save_queue(self):
Expand Down Expand Up @@ -201,7 +203,7 @@ def get_type_from_key(self, key: str) -> str:
return key.split()[0]

# TODO: TRON-2305 - In an ideal world, we wouldn't be passing around state/state_data dicts. It would be a lot nicer to have regular objects here
def _serialize_item(self, key: Literal[runstate.JOB_STATE, runstate.JOB_RUN_STATE], state: Dict[str, Any]) -> str: # type: ignore
def _serialize_item(self, key: Literal[runstate.JOB_STATE, runstate.JOB_RUN_STATE], state: Dict[str, Any]) -> Optional[str]: # type: ignore
if key == runstate.JOB_STATE:
return Job.to_json(state)
elif key == runstate.JOB_RUN_STATE:
Expand Down Expand Up @@ -239,11 +241,9 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:

pickled_val, json_val = value
num_partitions = math.ceil(len(pickled_val) / OBJECT_SIZE)
num_json_val_partitions = math.ceil(len(json_val) / OBJECT_SIZE)
num_json_val_partitions = math.ceil(len(json_val) / OBJECT_SIZE) if json_val else 0
items = []

# Use the maximum number of partitions (JSON can be larger
# than pickled value so this makes sure we save the entire item)
max_partitions = max(num_partitions, num_json_val_partitions)
for index in range(max_partitions):
item = {
Expand All @@ -263,17 +263,19 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
"num_partitions": {
"N": str(num_partitions),
},
"json_val": {
"S": json_val[index * OBJECT_SIZE : min(index * OBJECT_SIZE + OBJECT_SIZE, len(json_val))]
},
"num_json_val_partitions": {
"N": str(num_json_val_partitions),
},
},
"TableName": self.name,
},
}

if json_val:
item["Put"]["Item"]["json_val"] = {
"S": json_val[index * OBJECT_SIZE : min(index * OBJECT_SIZE + OBJECT_SIZE, len(json_val))]
}
item["Put"]["Item"]["num_json_val_partitions"] = {
"N": str(num_json_val_partitions),
}

count = 0
items.append(item)

Expand Down
3 changes: 2 additions & 1 deletion tron/utils/persistable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from abc import abstractmethod
from typing import Any
from typing import Dict
from typing import Optional


class Persistable(ABC):
    """Interface for state objects that can be rendered as JSON.

    Implementations return the JSON string on success, or None when
    serialization fails (so callers may fall back to pickles).
    """

    @staticmethod
    @abstractmethod
    def to_json(state_data: Dict[Any, Any]) -> Optional[str]:
        ...

0 comments on commit 2c41653

Please sign in to comment.