Skip to content

Commit

Permalink
refactor: handle additional args with kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Nov 21, 2024
1 parent 3d748aa commit eed85f3
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions aipipeline/prediction/download_crop_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,24 @@ def run_pipeline(argv=None):

parser = argparse.ArgumentParser(description="Download and crop unknown images.")
parser.add_argument("--config", required=True, help="Config file path")
parser.add_argument("--labels", required=True, help="Comma separated list of labels to download")
parser.add_argument("--labels", required=False, help="Comma separated list of labels to download")
parser.add_argument("--download-args", required=False, help="Additional arguments for download")
parser.add_argument("--download-dir", required=True, help="Directory to download images")
parser.add_argument("--download-dir", required=False, help="Directory to download images")
parser.add_argument("--skip-clean", required=False, default=False,
help="Skip cleaning of previously downloaded data")
args, beam_args = parser.parse_known_args(argv)
options = PipelineOptions(beam_args)
conf_files, config_dict = setup_config(args.config, silent=True)

if not os.path.exists(args.download_dir):
os.makedirs(args.download_dir)
if args.download_dir:
download_path = Path(args.download_dir)
else:
download_path = Path(config_dict["data"]["processed_path"])

# Override the config
config_dict["data"]["labels"] = args.labels
if args.download_args:
config_dict["data"]["download_args"] = [args.download_args]
if args.labels:
config_dict["data"]["labels"] = args.labels

labels = args.labels.split(",")
download_args = config_dict["data"]["download_args"]
labels = extract_labels_config(config_dict)

# Print the new config
logger.info("Configuration:")
Expand All @@ -67,33 +66,42 @@ def run_pipeline(argv=None):
# Make sure the download directory is contained in the config docker mounts
if "docker" in config_dict and "bind_volumes" in config_dict["docker"]:
for mount in config_dict["docker"]["bind_volumes"]:
if mount in args.download_dir:
if mount in download_path.as_posix():
break
else:
raise ValueError(f"Download directory {args.download_dir} not in docker mounts")

# Make sure the download directory is not a child of the processed directory - this is a safety check
# Otherwise, the processed data could be deleted
bind_volumes = config_dict["docker"]["bind_volumes"].keys()
# Check if the download directory is a child of any of the bind volumes in the docker config
# Check if the download directory is a child of any bind volumes in the docker config
found = False
for volume in bind_volumes:
if Path(args.download_dir).is_relative_to(volume):
if download_path.is_relative_to(volume):
found = True
break
if not found:
raise ValueError(
f"Download directory {args.download_dir} is not a child of any of the bind volumes in the docker config")

if not args.skip_clean:
clean(args.download_dir)
clean(download_path.as_posix())

kwargs = {}
if args.download_args:
kwargs["additional_args"] = args.download_args
if args.download_dir:
kwargs["download_dir"] = args.download_dir
processed_dir = args.download_dir
else:
processed_dir = config_dict["data"]["processed_path"]

with beam.Pipeline(options=options) as p:
(
p
| "Start download" >> beam.Create([labels])
| "Download labeled data" >> beam.Map(download, conf_files=conf_files, config_dict=config_dict, additional_args=download_args, download_dir=args.download_dir)
| "Crop ROI" >> beam.Map(crop_rois_voc, config_dict=config_dict, processed_dir=args.download_dir)
| "Download labeled data" >> beam.Map(download, conf_files=conf_files, config_dict=config_dict, **kwargs)
| "Crop ROI" >> beam.Map(crop_rois_voc, config_dict=config_dict, processed_dir=processed_dir)
| "Log results" >> beam.Map(logger.info)
)

Expand Down

0 comments on commit eed85f3

Please sign in to comment.