Skip to content

Commit

Permalink
Serializers with streaming io (#41)
Browse files Browse the repository at this point in the history
* refactor: modify serializer protocol and implementations to work with streaming io

* refactor: adapt local runtime to new serializer protocol

* refactor: adapt cli to refactored local runtime (save progress)

* refactor: adapt cli to refactored local runtime

* refactor: adapt yaml serializer example

* docs: update serializer performance section

* test: manually test memory allocation

* test: add serializer memory footprint tests

* chore: upgrade pytest to ^6.2 and commit poetry.lock

* docs: remove memory footprint as a limitation of the local runtime

* fix: do not make assertions about json indentation on windows

* test: cover custom exceptions in runtime.cli.invoke_with_locations
  • Loading branch information
larribas authored Sep 24, 2021
1 parent 1f22024 commit f72db46
Show file tree
Hide file tree
Showing 40 changed files with 1,887 additions and 337 deletions.
1 change: 0 additions & 1 deletion .github/workflows/continuous-integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,3 @@ jobs:
fail_ci_if_error: true
env_vars: OS,PYTHON
verbose: true

3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,3 @@ dmypy.json

# Pyre type checker
.pyre/

# Poetry
poetry.lock
12 changes: 0 additions & 12 deletions dagger/dag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from dagger.input import FromNodeOutput, FromParam
from dagger.input import validate_name as validate_input_name
from dagger.output import validate_name as validate_output_name
from dagger.serializer import SerializationError
from dagger.task import SupportedInputs as SupportedTaskInputs
from dagger.task import Task

Expand Down Expand Up @@ -208,9 +207,6 @@ def validate_parameters(
------
ValueError
If the set of parameters does not contain all the required inputs.
SerializationError
If the value provided for a parameter is not compatible with the serializer defined for that input.
"""
missing_params = inputs.keys() - params.keys()
if missing_params:
Expand All @@ -224,14 +220,6 @@ def validate_parameters(
f"The following parameters were supplied to this DAG, but are not necessary: {sorted(list(superfluous_params))}"
)

for input_name in inputs:
try:
inputs[input_name].serializer.serialize(params[input_name])
except SerializationError as e:
raise SerializationError(
f"The value supplied for input '{input_name}' is not compatible with the serializer defined for that input ({inputs[input_name].serializer}): {e}"
)


def _validate_node_name(name: str):
if not VALID_NAME.match(name):
Expand Down
45 changes: 29 additions & 16 deletions dagger/runtime/cli/invoke_with_locations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Command-line Interface to run DAGs or Tasks taking their inputs from files and storing their outputs into files."""
import tempfile
from typing import Any, Iterable, List, Mapping

import dagger.runtime.local as local
Expand Down Expand Up @@ -43,6 +44,9 @@ def invoke_with_locations(
TypeError
When any of the outputs cannot be obtained from the return value of their node
OSError
When there is a problem with the operating system's permissions to access the supplied input/output locations.
SerializationError
When some of the outputs cannot be serialized with the specified Serializer
"""
Expand All @@ -55,14 +59,24 @@ def invoke_with_locations(

params = _deserialized_params(nested_node, input_locations)

outputs = local.invoke(nested_node.node, params)

for output_name in output_locations:
store_output_in_location(
output_location=output_locations[output_name],
output_value=outputs[output_name],
with tempfile.TemporaryDirectory() as tmp:
outputs = local.invoke(
nested_node.node,
params=params,
outputs=local.StoreSerializedOutputsInPath(tmp),
)

for output_name in output_locations:
try:
store_output_in_location(
output_location=output_locations[output_name],
output_value=outputs[output_name],
)
except (OSError, FileExistsError, IsADirectoryError, PermissionError) as e:
raise OSError(
f"When storing output '{output_name}', we got the following error: {str(e)}"
) from e


def _validate_inputs(
input_names: Iterable[str],
Expand Down Expand Up @@ -95,15 +109,14 @@ def _deserialized_params(
"""Retrieve and deserialize all the parameters expected by a Node."""
params = {}
for input_name in input_locations:
input_value = retrieve_input_from_location(input_locations[input_name])
input_type = nested_node.node.inputs[input_name]

if isinstance(input_value, local.PartitionedOutput):
params[input_name] = [
input_type.serializer.deserialize(partition)
for partition in input_value
]
else:
params[input_name] = input_type.serializer.deserialize(input_value)
try:
params[input_name] = retrieve_input_from_location(
input_location=input_locations[input_name],
serializer=nested_node.node.inputs[input_name].serializer,
)
except (FileNotFoundError, PermissionError) as e:
raise OSError(
f"When retrieving input '{input_name}' from the provided location, we got the following error: {str(e)}"
) from e

return params
56 changes: 38 additions & 18 deletions dagger/runtime/cli/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@

import json
import os
from typing import Any

from dagger.runtime.local import NodeOutput, PartitionedOutput
from dagger.serializer import Serializer

PARTITION_MANIFEST_FILENAME = "partitions.json"


def retrieve_input_from_location(input_location: str) -> NodeOutput:
def retrieve_input_from_location(
input_location: str,
serializer: Serializer,
) -> Any:
"""
Given an input location, retrieve the contents of the file/directory it points to.
Expand All @@ -24,10 +29,13 @@ def retrieve_input_from_location(input_location: str) -> NodeOutput:
and concatenate all existing partitions based on the lexicographical order
of their filenames.
serializer
The serializer implementation to use to deserialize the input file.
Returns
-------
The serialized version of the input. If the input is partitioned, it returns a list of serialized partitions.
The original value of the input. If the input is partitioned, it returns an iterable of values.
Raises
Expand All @@ -47,29 +55,34 @@ def retrieve_input_from_location(input_location: str) -> NodeOutput:
]
sorted_partition_filenames = sorted(partition_filenames, key=int)

def load_lazily(partition_filename: str):
with open(os.path.join(input_location, partition_filename), "rb") as f:
return f.read()
def load(partition_filename: str) -> Any:
with open(os.path.join(input_location, partition_filename), "rb") as reader:
return serializer.deserialize(reader)

return PartitionedOutput(map(load_lazily, sorted_partition_filenames))
return [load(fname) for fname in sorted_partition_filenames]

else:
with open(input_location, "rb") as f:
return f.read()
with open(input_location, "rb") as reader:
return serializer.deserialize(reader)


def store_output_in_location(output_location: str, output_value: NodeOutput):
def store_output_in_location(
output_location: str,
output_value: NodeOutput,
):
"""
Store a serialized output into the specified location.
It uses os.rename(): https://docs.python.org/3/library/os.html#os.rename
Parameters
----------
output_location
A pointer to a path (e.g. "/my/filesystem/file.txt").
The path must not exist previously.
output_value
The serialized representation of a node output.
A NodeOutput, pointing to the file that contains the serialized version of the output value.
It may be partitioned. If it is, we will treat the output_location as a directory
and dump each partition separately, together with a file named "partitions.json"
containing a json-serialized list with all the partitions.
Expand All @@ -79,11 +92,14 @@ def store_output_in_location(output_location: str, output_value: NodeOutput):
Raises
------
IsADirectoryError
If the output_location is a directory.
OSError
If the output location is a non-empty directory, in Unix
FileExistsError
If the output_location already exists.
If the output location already exists, in Windows
IsADirectoryError
If the output location exists and it is an empty directory, in Unix.
PermissionError
If the current execution context doesn't have enough permissions to read the file.
Expand All @@ -92,14 +108,18 @@ def store_output_in_location(output_location: str, output_value: NodeOutput):
os.mkdir(output_location)
partition_filenames = []

for i, partition in enumerate(output_value):
for i, src in enumerate(output_value):
partition_filename = str(i)
os.rename(
src.filename,
os.path.join(
output_location,
partition_filename,
),
)
partition_filenames.append(partition_filename)
with open(os.path.join(output_location, partition_filename), "wb") as f:
f.write(partition)

with open(os.path.join(output_location, PARTITION_MANIFEST_FILENAME), "w") as p:
json.dump(partition_filenames, p)
else:
with open(output_location, "wb") as f:
f.write(output_value)
os.rename(output_value.filename, output_location)
7 changes: 6 additions & 1 deletion dagger/runtime/local/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""Run DAGs or nodes in memory."""

from dagger.runtime.local.dag import invoke # noqa
from dagger.runtime.local.invoke import ( # noqa
ReturnDeserializedOutputs,
StoreSerializedOutputsInPath,
invoke,
)
from dagger.runtime.local.types import ( # noqa
NodeOutput,
NodeOutputs,
OutputFile,
PartitionedOutput,
)
82 changes: 34 additions & 48 deletions dagger/runtime/local/dag.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Run a DAG in memory."""
import itertools
import os
from typing import Any, Dict, Iterable, Mapping, Union

from dagger.dag import DAG, Node, validate_parameters
from dagger.input import FromNodeOutput, FromParam
from dagger.runtime.local.task import _invoke_task
from dagger.runtime.local.output import load
from dagger.runtime.local.task import invoke_task
from dagger.runtime.local.types import (
NodeExecutions,
NodeOutput,
Expand All @@ -16,49 +18,24 @@
from dagger.task import Task


def invoke(
def invoke_node(
node: Union[DAG, Task],
params: Mapping[str, Any] = None,
params: Mapping[str, Any],
output_path: str,
) -> Mapping[str, NodeOutput]:
"""
Invoke a node with a series of parameters.
Parameters
----------
node
Node to execute
params
Inputs to the task, indexed by input/parameter name.
Returns
-------
Serialized outputs of the task, indexed by output name.
Raises
------
ValueError
When any required parameters are missing
TypeError
When any of the outputs cannot be obtained from the return value of the task's function
SerializationError
When some of the outputs cannot be serialized with the specified Serializer
"""
"""Invoke a Node locally with the specified parameters and dump the serialized outputs on the path provided."""
if isinstance(node, DAG):
return _invoke_dag(node, params=params)
return invoke_dag(node, output_path=output_path, params=params)
else:
return _invoke_task(node, params=params)
return invoke_task(node, output_path=output_path, params=params)


def _invoke_dag(
def invoke_dag(
dag: DAG,
params: Mapping[str, Any] = None,
params: Mapping[str, Any],
output_path: str,
) -> NodeOutputs:
params = params or {}
"""Invoke a DAG locally with the specified parameters and dump the serialized outputs on the path provided."""
validate_parameters(dag.inputs, params)

outputs: Dict[str, NodeExecutions] = {}
Expand All @@ -68,16 +45,23 @@ def _invoke_dag(
node = dag.nodes[node_name]

try:
outputs[node_name] = PartitionedOutput(
[
invoke(node, params=p)
for p in _node_param_partitions(
node=node,
params=params,
outputs=outputs,
)
]
)
partitions = []

for i, p in enumerate(
_node_param_partitions(
node=node,
params=params,
outputs=outputs,
)
):
node_output_path = os.path.join(output_path, "nodes", node_name, str(i))
os.makedirs(node_output_path)

partitions.append(
invoke_node(node, params=p, output_path=node_output_path)
)

outputs[node_name] = PartitionedOutput(partitions)

if not node.partition_by_input:
outputs[node_name] = next(outputs[node_name])
Expand Down Expand Up @@ -155,6 +139,8 @@ def _node_param_from_output(
node_output: NodeOutput,
) -> Union[Any, PartitionedOutput[Any]]:
if isinstance(node_output, PartitionedOutput):
return PartitionedOutput(map(lambda v: serializer.deserialize(v), node_output))
return PartitionedOutput(
map(lambda n: load(filename=n.filename, serializer=serializer), node_output)
)
else:
return serializer.deserialize(node_output)
return load(filename=node_output.filename, serializer=serializer)
Loading

0 comments on commit f72db46

Please sign in to comment.