Skip to content

Commit

Permalink
ShellJob: Fix RemoteData handling
Browse files Browse the repository at this point in the history
The `filenames` input was not taken into account for `RemoteData` input
nodes.
  • Loading branch information
sphuber committed Jan 18, 2025
1 parent 189df63 commit f2acef2
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 12 deletions.
36 changes: 26 additions & 10 deletions src/aiida_shell/calculations/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aiida.common.datastructures import CalcInfo, CodeInfo, FileCopyOperation
from aiida.common.folders import Folder
from aiida.engine import CalcJob, CalcJobProcessSpec
from aiida.orm import Data, Dict, FolderData, List, RemoteData, SinglefileData, to_aiida_type
from aiida.orm import Computer, Data, Dict, FolderData, List, RemoteData, SinglefileData, to_aiida_type
from aiida.parsers import Parser

from aiida_shell.data import EntryPointData, PickledData
Expand Down Expand Up @@ -281,9 +281,11 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
inputs = {}

nodes = inputs.get('nodes', {})
computer = inputs['code'].computer
filenames = (inputs.get('filenames', None) or Dict()).get_dict()
arguments = (inputs.get('arguments', None) or List()).get_list()
outputs = (inputs.get('outputs', None) or List()).get_list()
use_symlinks = inputs['metadata']['options']['use_symlinks']
filename_stdin = inputs['metadata']['options'].get('filename_stdin', None)
filename_stdout = inputs['metadata']['options'].get('output_filename', None)
default_retrieved_temporary = list(self.DEFAULT_RETRIEVED_TEMPORARY)
Expand All @@ -300,7 +302,10 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
if filename_stdin and filename_stdin in processed_arguments:
processed_arguments.remove(filename_stdin)

remote_copy_list, remote_symlink_list = self.handle_remote_data_nodes(inputs)
remote_data_nodes = {key: node for key, node in nodes.items() if isinstance(node, RemoteData)}
remote_copy_list, remote_symlink_list = self.handle_remote_data_nodes(
remote_data_nodes, filenames, computer, use_symlinks
)

code_info = CodeInfo()
code_info.code_uuid = inputs['code'].uuid
Expand Down Expand Up @@ -329,16 +334,22 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
return calc_info

@staticmethod
def handle_remote_data_nodes(inputs: dict[str, Data]) -> tuple[list[t.Any], list[t.Any]]:
"""Handle a ``RemoteData`` that was passed in the ``nodes`` input.
def handle_remote_data_nodes(
remote_data_nodes: dict[str, RemoteData], filenames: dict[str, str], computer: Computer, use_symlinks: bool
) -> tuple[list[t.Any], list[t.Any]]:
"""Handle all ``RemoteData`` nodes that were passed in the ``nodes`` input.
:param inputs: The inputs dictionary.
:param remote_data_nodes: The ``RemoteData`` input nodes.
:param filenames: A dictionary of explicit filenames to use for the ``nodes`` to be written to ``dirpath``.
:returns: A tuple of two lists, the ``remote_copy_list`` and the ``remote_symlink_list``.
"""
use_symlinks: bool = inputs['metadata']['options']['use_symlinks'] # type: ignore[index]
computer_uuid = inputs['code'].computer.uuid # type: ignore[union-attr]
remote_nodes = [node for node in inputs.get('nodes', {}).values() if isinstance(node, RemoteData)]
instructions = [(computer_uuid, f'{node.get_remote_path()}/*', '.') for node in remote_nodes]
instructions = []

for key, node in remote_data_nodes.items():
if key in filenames:
instructions.append((computer.uuid, node.get_remote_path(), filenames[key]))
else:
instructions.append((computer.uuid, f'{node.get_remote_path()}/*', '.'))

if use_symlinks:
return [], instructions
Expand Down Expand Up @@ -407,7 +418,10 @@ def process_arguments_and_nodes(
self.write_folder_data(node, dirpath, filename)
argument_interpolated = argument.format(**{placeholder: filename or placeholder})
elif isinstance(node, RemoteData):
self.handle_remote_data(node)
# Only the placeholder needs to be formatted. The content of the remote data itself is handled by the
# engine through the instructions created in ``handle_remote_data_nodes``.
filename = prepared_filenames[placeholder]
argument_interpolated = argument.format(**{placeholder: filename or placeholder})
else:
argument_interpolated = argument.format(**{placeholder: str(node.value)})

Expand Down Expand Up @@ -465,6 +479,8 @@ def prepare_filenames(self, nodes: dict[str, SinglefileData], filenames: dict[st
raise RuntimeError(
f'node `{key}` contains the file `{f}` which overlaps with a reserved output filename.'
)
elif isinstance(node, RemoteData):
filename = filenames.get(key, None)
else:
continue

Expand Down
35 changes: 34 additions & 1 deletion tests/calculations/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_nodes_folder_data(generate_calc_job, generate_code, tmp_path):

@pytest.mark.parametrize('use_symlinks', (True, False))
def test_nodes_remote_data(generate_calc_job, generate_code, tmp_path, aiida_localhost, use_symlinks):
"""Test the ``nodes`` input with ``RemoteData`` nodes ."""
"""Test the ``nodes`` input with ``RemoteData`` nodes."""
inputs = {
'code': generate_code(),
'arguments': [],
Expand All @@ -107,6 +107,39 @@ def test_nodes_remote_data(generate_calc_job, generate_code, tmp_path, aiida_loc
assert sorted(calc_info.remote_copy_list) == [(aiida_localhost.uuid, str(tmp_path / '*'), '.')]


def test_nodes_remote_data_filename(generate_calc_job, generate_code, tmp_path, aiida_localhost):
"""Test the ``nodes`` and ``filenames`` inputs with ``RemoteData`` nodes."""
remote_path_a = tmp_path / 'remote_a'
remote_path_b = tmp_path / 'remote_b'
remote_path_a.mkdir()
remote_path_b.mkdir()
(remote_path_a / 'file_a.txt').write_text('content a')
(remote_path_b / 'file_b.txt').write_text('content b')
remote_data_a = RemoteData(remote_path=str(remote_path_a.absolute()), computer=aiida_localhost)
remote_data_b = RemoteData(remote_path=str(remote_path_b.absolute()), computer=aiida_localhost)

inputs = {
'code': generate_code(),
'arguments': ['{remote_a}'],
'nodes': {
'remote_a': remote_data_a,
'remote_b': remote_data_b,
},
'filenames': {'remote_a': 'target_remote'},
}
dirpath, calc_info = generate_calc_job('core.shell', inputs)

code_info = calc_info.codes_info[0]
assert code_info.cmdline_params == ['target_remote']

assert calc_info.remote_symlink_list == []
assert sorted(calc_info.remote_copy_list) == [
(aiida_localhost.uuid, str(remote_path_a), 'target_remote'),
(aiida_localhost.uuid, str(remote_path_b / '*'), '.'),
]
assert sorted(p.name for p in dirpath.iterdir()) == []


def test_nodes_base_types(generate_calc_job, generate_code):
"""Test the ``nodes`` input with ``BaseType`` nodes ."""
inputs = {
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def factory(entry_point_name='core.shell', store_provenance=False, filepath_retr


@pytest.fixture
def generate_calc_job(tmp_path):
def generate_calc_job(tmp_path_factory):
"""Create a :class:`aiida.engine.CalcJob` instance with the given inputs.
The fixture will call ``prepare_for_submission`` and return a tuple of the temporary folder that was passed to it,
Expand All @@ -81,6 +81,7 @@ def factory(
which ensures that all input files are written, including those by the scheduler plugin, such as the
submission script.
"""
tmp_path = tmp_path_factory.mktemp('calc_job_submit_dir')
manager = get_manager()
runner = manager.get_runner()

Expand Down
22 changes: 22 additions & 0 deletions tests/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,28 @@ def test_nodes_remote_data(tmp_path, aiida_localhost, use_symlinks):
assert (dirpath_working / 'filled' / 'file_b.txt').read_text() == 'content b'


def test_nodes_remote_data_filename(tmp_path_factory, aiida_localhost):
"""Test copying contents of a ``RemoteData`` to specific subdirectory."""
dirpath_remote = tmp_path_factory.mktemp('remote')
dirpath_source = dirpath_remote / 'source'
dirpath_source.mkdir()
(dirpath_source / 'file.txt').touch()
remote_data = RemoteData(remote_path=str(dirpath_remote), computer=aiida_localhost)

results, node = launch_shell_job(
'echo',
arguments=['{remote}'],
nodes={'remote': remote_data},
filenames={'remote': 'sub_directory'},
)
assert node.is_finished_ok
assert results['stdout'].get_content().strip() == 'sub_directory'
dirpath_working = pathlib.Path(node.outputs.remote_folder.get_remote_path())
assert (dirpath_working / 'sub_directory').is_dir()
assert (dirpath_working / 'sub_directory' / 'source').is_dir()
assert (dirpath_working / 'sub_directory' / 'source' / 'file.txt').is_file()


def test_nodes_base_types():
"""Test a shellfunction that specifies positional CLI arguments that are interpolated by the ``kwargs``."""
nodes = {
Expand Down

0 comments on commit f2acef2

Please sign in to comment.