Add shared_memory to task with extended resources #3096

Open · wants to merge 9 commits into base: master
4 changes: 4 additions & 0 deletions flytekit/core/constants.py
@@ -38,3 +38,7 @@
CACHE_KEY_METADATA = "cache-key-metadata"

SERIALIZATION_FORMAT = "serialization-format"

# Shared memory mount name and path
SHARED_MEMORY_MOUNT_NAME = "flyte-shared-memory"
SHARED_MEMORY_MOUNT_PATH = "/dev/shm"
Contributor

Insecure hardcoded temporary file path

Consider using a more secure temporary file location instead of hardcoding '/dev/shm'. The shared memory directory could potentially be accessed by other processes on the system. Consider using 'tempfile.gettempdir()' to get a secure temporary directory location.

Code suggestion
Check the AI-generated fix before applying
Suggested change
SHARED_MEMORY_MOUNT_PATH = "/dev/shm"
import tempfile
SHARED_MEMORY_MOUNT_PATH = tempfile.gettempdir()

Code Review Run #3d3e41


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged
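
For context, /dev/shm is the conventional shared-memory mount point inside Linux containers, which is why the flag is a false positive here: the path is not a scratch directory picked by flytekit but the location the kernel and most ML libraries (e.g. PyTorch DataLoader workers) expect. As a hypothetical sketch only, the constants above would typically correspond to a memory-backed emptyDir volume on the Kubernetes side; the actual injection happens in the Flyte backend, not in this module.

# Hypothetical illustration: the Kubernetes objects a memory-backed
# shared-memory mount at /dev/shm commonly maps to. The size_limit value
# here is made up for the example.
from kubernetes.client import V1EmptyDirVolumeSource, V1Volume, V1VolumeMount

from flytekit.core.constants import SHARED_MEMORY_MOUNT_NAME, SHARED_MEMORY_MOUNT_PATH

shm_volume = V1Volume(
    name=SHARED_MEMORY_MOUNT_NAME,
    empty_dir=V1EmptyDirVolumeSource(medium="Memory", size_limit="2Gi"),
)
shm_mount = V1VolumeMount(name=SHARED_MEMORY_MOUNT_NAME, mount_path=SHARED_MEMORY_MOUNT_PATH)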

10 changes: 8 additions & 2 deletions flytekit/core/node.py
@@ -3,11 +3,12 @@
import datetime
import typing
from typing import Any, Dict, List, Optional, Union
from typing import Literal as L

from flyteidl.core import tasks_pb2

from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.resources import Resources, construct_extended_resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.loggers import logger
@@ -193,6 +194,7 @@ def with_overrides(
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
shared_memory: Optional[Union[L[True], str]] = None,
pod_template: Optional[PodTemplate] = None,
*args,
**kwargs,
@@ -240,7 +242,11 @@ def with_overrides(

if accelerator is not None:
assert_not_promise(accelerator, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())

if shared_memory is not None:
assert_not_promise(shared_memory, "shared_memory")

self._extended_resources = construct_extended_resources(accelerator=accelerator, shared_memory=shared_memory)

self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)

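With the change above, a per-node override now flows through construct_extended_resources instead of building the proto inline. A minimal usage sketch (task and workflow names are illustrative):

from flytekit import task, workflow

@task(shared_memory=True)  # True: /dev/shm sized to the task's memory allocation
def train() -> str:
    return "done"

@workflow
def wf() -> str:
    # A string pins an explicit size at the call site, overriding the default.
    return train().with_overrides(shared_memory="1Gi")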
12 changes: 7 additions & 5 deletions flytekit/core/python_auto_container.py
@@ -5,6 +5,7 @@
from abc import ABC
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, TypeVar, Union
from typing import Literal as L

from flyteidl.core import tasks_pb2

@@ -13,7 +14,7 @@
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.resources import Resources, ResourceSpec, construct_extended_resources
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
@@ -51,6 +52,7 @@ def __init__(
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
shared_memory: Optional[Union[L[True], str]] = None,
**kwargs,
):
"""
@@ -78,6 +80,8 @@ def __init__(
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param shared_memory: If True, then shared memory will be attached to the container where the size is equal
to the allocated memory. If str, then the shared memory is set to that size.
"""
sec_ctx = None
if secret_requests:
@@ -128,6 +132,7 @@ def __init__(

self.pod_template = pod_template
self.accelerator = accelerator
self.shared_memory = shared_memory

@property
def task_resolver(self) -> TaskResolverMixin:
@@ -250,10 +255,7 @@ def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
"""
Returns the extended resources to allocate to the task on hosted Flyte.
"""
if self.accelerator is None:
return None

return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl())
return construct_extended_resources(accelerator=self.accelerator, shared_memory=self.shared_memory)


class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
33 changes: 33 additions & 0 deletions flytekit/core/resources.py
@@ -1,9 +1,13 @@
from dataclasses import dataclass, fields
from typing import List, Optional, Union
from typing import Literal as L

from flyteidl.core import tasks_pb2
from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.constants import SHARED_MEMORY_MOUNT_NAME, SHARED_MEMORY_MOUNT_PATH
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.models import task as task_models


@@ -102,6 +106,35 @@ def convert_resources_to_resource_model(
return task_models.Resources(requests=request_entries, limits=limit_entries)


def construct_extended_resources(
*,
accelerator: Optional[BaseAccelerator] = None,
shared_memory: Optional[Union[L[True], str]] = None,
) -> Optional[tasks_pb2.ExtendedResources]:
"""Convert public extended resources to idl.

:param accelerator: The accelerator to use for this task.
:param shared_memory: If True, then shared memory will be attached to the container where the size is equal
to the allocated memory. If str, then the shared memory is set to that size.
"""
kwargs = {}
if accelerator is not None:
kwargs["gpu_accelerator"] = accelerator.to_flyte_idl()
if isinstance(shared_memory, str) or shared_memory is True:
if shared_memory is True:
shared_memory = None
kwargs["shared_memory"] = tasks_pb2.SharedMemory(
mount_name=SHARED_MEMORY_MOUNT_NAME,
mount_path=SHARED_MEMORY_MOUNT_PATH,
size_limit=shared_memory,
)

if not kwargs:
return None

return tasks_pb2.ExtendedResources(**kwargs)
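
The helper's three branches at a glance, as a small sketch grounded in the code above and the unit tests further down:

from flytekit.core.resources import construct_extended_resources
from flytekit.extras.accelerators import T4

assert construct_extended_resources() is None  # nothing requested -> no message
ext = construct_extended_resources(shared_memory=True)
assert ext.shared_memory.size_limit == ""  # proto default; size follows the memory allocation
ext = construct_extended_resources(accelerator=T4, shared_memory="2Gi")
assert ext.shared_memory.size_limit == "2Gi" and ext.HasField("gpu_accelerator")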


def pod_spec_from_resources(
primary_container_name: Optional[str] = None,
Contributor

Parameter name and type change impact

The parameter name change from k8s_pod_name to primary_container_name with an optional type may require updates in calling code.

Code suggestion
Check the AI-generated fix before applying
Suggested change
primary_container_name: Optional[str] = None,
k8s_pod_name: str = None,
primary_container_name: Optional[str] = None, # New parameter, defaults to k8s_pod_name if set

Code Review Run #31b042


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

requests: Optional[Resources] = None,
7 changes: 7 additions & 0 deletions flytekit/core/task.py
@@ -5,6 +5,7 @@
import os
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload
from typing import Literal as L

from typing_extensions import ParamSpec # type: ignore

@@ -128,6 +129,7 @@ def task(
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
shared_memory: Optional[Union[L[True], str]] = None,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ...


@@ -167,6 +169,7 @@ def task(
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
shared_memory: Optional[Union[L[True], str]] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...


@@ -211,6 +214,7 @@ def task(
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
pickle_untyped: bool = False,
shared_memory: Optional[Union[L[True], str]] = None,
) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
@@ -341,6 +345,8 @@ def launch_dynamically():
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types.
:param shared_memory: If True, then shared memory will be attached to the container where the size is equal
to the allocated memory. If str, then the shared memory is set to that size.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
@@ -390,6 +396,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
pod_template_name=pod_template_name,
accelerator=accelerator,
pickle_untyped=pickle_untyped,
shared_memory=shared_memory,
)
update_wrapper(task_instance, decorated_fn)
return task_instance
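
For completeness, the two accepted forms at declaration time, mirroring the docstring semantics (task names and bodies are illustrative):

from flytekit import task

@task(shared_memory=True)  # attach /dev/shm sized to the task's memory allocation
def preprocess() -> int:
    return 1

@task(shared_memory="4Gi")  # or pin an explicit size
def train_model() -> int:
    return 2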
20 changes: 20 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
@@ -443,6 +443,26 @@ def wf(x: typing.List[int]):
assert task_spec.template.extended_resources.gpu_accelerator.device == "test_gpu"


def test_serialization_extended_resources_shared_memory(serialization_settings):
@task(
shared_memory="2Gi"
)
def t1(a: int) -> int:
return a + 1

arraynode_maptask = map_task(t1)

@workflow
def wf(x: typing.List[int]):
return arraynode_maptask(a=x)

od = OrderedDict()
get_serializable(od, serialization_settings, wf)
task_spec = od[arraynode_maptask]

assert task_spec.template.extended_resources.shared_memory.size_limit == "2Gi"


def test_supported_node_type():
@task
def test_task():
23 changes: 23 additions & 0 deletions tests/flytekit/unit/core/test_node_creation.py
@@ -531,6 +531,29 @@ def my_wf() -> str:
assert not accelerator.HasField("unpartitioned")


def test_override_shared_memory():
@task(shared_memory=True)
def bar() -> str:
return "hello"

@workflow
def my_wf() -> str:
return bar().with_overrides(shared_memory="128Mi")

serialization_settings = flytekit.configuration.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].task_node.overrides is not None
assert wf_spec.template.nodes[0].task_node.overrides.extended_resources is not None
shared_memory = wf_spec.template.nodes[0].task_node.overrides.extended_resources.shared_memory
assert shared_memory.size_limit == "128Mi"


def test_cache_override_values():
@task
def t1(a: str) -> str:
17 changes: 17 additions & 0 deletions tests/flytekit/unit/core/test_resources.py
@@ -8,7 +8,9 @@
from flytekit.core.resources import (
pod_spec_from_resources,
convert_resources_to_resource_model,
construct_extended_resources,
)
from flytekit.extras.accelerators import T4

_ResourceName = _task_models.Resources.ResourceName

@@ -155,3 +157,18 @@ def test_pod_spec_from_resources_requests_set():
)
pod_spec = pod_spec_from_resources(primary_container_name=primary_container_name, requests=requests, limits=limits)
assert expected_pod_spec == pod_spec
Contributor

Consider type conversion before assertion comparison

The assertion expected_pod_spec == pod_spec is comparing two different types - V1PodSpec vs dict. Consider using V1PodSpec(**pod_spec) to convert the dictionary to a V1PodSpec object before comparison.

Code suggestion
Check the AI-generated fix before applying
Suggested change
assert expected_pod_spec == pod_spec
assert expected_pod_spec == V1PodSpec(**pod_spec)

Code Review Run #31b042


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged



@pytest.mark.parametrize("shared_memory", [None, False])
def test_construct_extended_resources_shared_memory_none(shared_memory):
resources = construct_extended_resources(shared_memory=shared_memory)
Comment on lines +162 to +164
Contributor

Consider consolidating redundant test cases

Consider consolidating the test cases for None and False into a single test case since they produce the same behavior. Both values result in resources being None.

Code suggestion
Check the AI-generated fix before applying
Suggested change
@pytest.mark.parametrize("shared_memory", [None, False])
def test_construct_extended_resources_shared_memory_none(shared_memory):
resources = construct_extended_resources(shared_memory=shared_memory)
def test_construct_extended_resources_shared_memory_none():
resources = construct_extended_resources(shared_memory=None)

Code Review Run #3d3e41


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

assert resources is None


@pytest.mark.parametrize("shared_memory, expected_size_limit", [
("2Gi", "2Gi"),
(True, ""),
])
def test_construct_extended_resources_shared_memory(shared_memory, expected_size_limit):
resources = construct_extended_resources(shared_memory=shared_memory)
assert resources.shared_memory.size_limit == expected_size_limit
4 changes: 3 additions & 1 deletion tests/flytekit/unit/models/test_tasks.py
@@ -5,6 +5,7 @@
from flyteidl.core.tasks_pb2 import ExtendedResources, TaskMetadata
from google.protobuf import text_format

from flytekit.core.resources import construct_extended_resources
import flytekit.models.interface as interface_models
import flytekit.models.literals as literal_models
from flytekit import Description, Documentation, SourceCode
@@ -110,7 +111,7 @@ def test_task_template(in_tuple):
{"d": "e"},
),
config={"a": "b"},
extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
extended_resources=construct_extended_resources(accelerator=T4, shared_memory="2Gi"),
)
assert obj.id.resource_type == identifier.ResourceType.TASK
assert obj.id.project == "project"
@@ -130,6 +131,7 @@ def test_task_template(in_tuple):
assert obj.extended_resources.gpu_accelerator.device == "nvidia-tesla-t4"
assert not obj.extended_resources.gpu_accelerator.HasField("unpartitioned")
assert not obj.extended_resources.gpu_accelerator.HasField("partition_size")
assert obj.extended_resources.shared_memory.size_limit == "2Gi"


def test_task_spec():