From c4466d14b8c92be12c108d863e542c65d94fac20 Mon Sep 17 00:00:00 2001
From: Charitarth Chugh <37895518+charitarthchugh@users.noreply.github.com>
Date: Tue, 26 Nov 2024 12:01:12 -0500
Subject: [PATCH] feat(data): start building dataloader

---
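Notes:

The DataModule below pulls everything it needs from a handful of DictConfig
nodes. The key names are taken from the code in this patch; the concrete
values shown here are illustrative assumptions only, as are the per-split
keys of num_workers/batch_size (their dataloader usage is outside this diff):

    from omegaconf import OmegaConf

    from lightningsparseinst.data.datamodule import DataModule

    cfg = OmegaConf.create(
        {
            "dataset": {
                "ref": "my-fiftyone-dataset",  # name passed to fo.load_dataset()
                "gt_field": "ground_truth",  # label field holding the annotations
                "detection_field": "detections",  # attribute with the instance list
                "classes": ["cat", "dog"],  # optional: list, regex string, or absent
                "max_num_instances_per_image": 50,
                "transform": {"_target_": "albumentations.Compose", "transforms": []},
            },
            "split_names": {"train": "train", "validation": "validation", "test": "test"},
            "num_workers": {"train": 4, "val": 4, "test": 4},
            "batch_size": {"train": 8, "val": 8, "test": 8},
        }
    )

    dm = DataModule(
        dataset=cfg.dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        split_names=cfg.split_names,
        accelerator="gpu",
    )
    dm.setup("fit")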
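setup() narrows the FiftyOne dataset with view expressions before wrapping it
in torch datasets. A minimal standalone sketch of the three branches, assuming
a hypothetical dataset named "demo" whose labels live in
"ground_truth.detections" (the same composed f"{gt_field}.{detection_field}"
path convention the patch uses):

    import fiftyone as fo
    from fiftyone import ViewField as F

    ds = fo.load_dataset("demo")
    ds.compute_metadata()  # caches image width/height on each sample

    # list -> keep only instances whose label is in an explicit class list
    view = ds.filter_labels("ground_truth.detections", F("label").is_in(["cat", "dog"]))

    # str -> treated as a regular expression over the label
    view = ds.filter_labels("ground_truth.detections", F("label").re_match("^veh"))

    # nothing configured -> fall back to every class present in the field
    classes = ds.distinct("ground_truth.detections.label")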
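SegmentationDataset is imported from lightningsparseinst.data.dataset and is
not part of this diff; from the call sites below, its interface is roughly the
following sketch (everything beyond the visible keyword arguments, including
the direction of labels_map_rev, is an assumption):

    from typing import Dict, Optional

    import fiftyone as fo
    from albumentations import Compose
    from torch.utils.data import Dataset

    class SegmentationDataset(Dataset):
        """Wraps a FiftyOne dataset/view for instance segmentation."""

        def __init__(
            self,
            fiftyone_dataset: fo.Dataset,
            split: str,  # presumably matched against sample tags
            gt_field: str,  # e.g. "ground_truth"
            detection_field: str,  # e.g. "detections"
            transform: Optional[Compose] = None,  # only passed for train here
            max_num_instances_per_image: int = 50,
        ) -> None: ...

        @property
        def labels_map_rev(self) -> Dict[str, int]:
            """Mapping between class names and contiguous integer ids."""
            ...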
 src/lightningsparseinst/data/datamodule.py | 86 ++++++++++++++++------
 1 file changed, 64 insertions(+), 22 deletions(-)

diff --git a/src/lightningsparseinst/data/datamodule.py b/src/lightningsparseinst/data/datamodule.py
index 3c7d19b..114a5e1 100644
--- a/src/lightningsparseinst/data/datamodule.py
+++ b/src/lightningsparseinst/data/datamodule.py
@@ -3,9 +3,12 @@
 from pathlib import Path
 from typing import List, Mapping, Optional
 
+import fiftyone as fo
 import hydra
-import lightning.pytorch as pl
+import lightning as L
 import omegaconf
+from albumentations import Compose
+from fiftyone import ViewField
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.dataloader import default_collate
@@ -14,6 +17,8 @@
 from nn_core.common import PROJECT_ROOT
 from nn_core.nn_types import Split
 
+from lightningsparseinst.data.dataset import SegmentationDataset
+
 pylogger = logging.getLogger(__name__)
 
 
@@ -76,9 +81,7 @@ def load(src_path: Path) -> "MetaData":
                 key, value = line.strip().split("\t")
                 class_vocab[key] = value
 
-        return MetaData(
-            class_vocab=class_vocab,
-        )
+        return MetaData(class_vocab=class_vocab)
 
     def __repr__(self) -> str:
         attributes = ",\n    ".join([f"{key}={value}" for key, value in self.__dict__.items()])
@@ -99,29 +102,31 @@ def collate_fn(samples: List, split: Split, metadata: MetaData):
     return default_collate(samples)
 
 
-class MyDataModule(pl.LightningDataModule):
+class DataModule(L.LightningDataModule):
     def __init__(
         self,
         dataset: DictConfig,
         num_workers: DictConfig,
         batch_size: DictConfig,
+        split_names: DictConfig,
         accelerator: str,  # example
-        val_images_fixed_idxs: List[int],
     ):
         super().__init__()
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
+        self.split_names = split_names
 
         # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#gpus
         self.pin_memory: bool = accelerator is not None and str(accelerator) == "gpu"
 
+        self.fiftyone_dataset: Optional[fo.Dataset] = None
+        self.classes: Optional[List[str] | str] = None
+
         self.train_dataset: Optional[Dataset] = None
         self.val_dataset: Optional[Dataset] = None
         self.test_dataset: Optional[Dataset] = None
-
-        # example
-        self.val_images_fixed_idxs: List[int] = val_images_fixed_idxs
+        self.transform: Optional[Compose] = None
 
     @cached_property
     def metadata(self) -> MetaData:
@@ -136,25 +141,62 @@ def metadata(self) -> MetaData:
         if self.train_dataset is None:
             self.setup(stage="fit")
 
-        return MetaData(class_vocab={i: name for i, name in enumerate(self.train_dataset.features["y"].names)})
+        return MetaData(class_vocab=self.train_dataset.labels_map_rev)
 
     def prepare_data(self) -> None:
         # download only
         pass
 
     def setup(self, stage: Optional[str] = None):
-        self.transform = hydra.utils.instantiate(self.dataset.transforms)
-
-        self.hf_datasets = hydra.utils.instantiate(self.dataset)
-        self.hf_datasets.set_transform(self.transform)
-
-        # Here you should instantiate your dataset, you may also split the train into train and validation if needed.
+        self.fiftyone_dataset = fo.load_dataset(self.dataset.ref)
+        self.fiftyone_dataset.compute_metadata()
+
+        self.transform = hydra.utils.instantiate(self.dataset.transform)
+        # Label filtering logic
+        self.classes = self.dataset.classes if "classes" in self.dataset.keys() else None
+        if self.classes:
+            if isinstance(self.classes, list):
+                self.fiftyone_dataset = self.fiftyone_dataset.filter_labels(
+                    f"{self.dataset.gt_field}.{self.dataset.detection_field}", ViewField("label").is_in(self.classes)
+                )
+            elif isinstance(self.classes, str):
+                # regex case
+                self.fiftyone_dataset = self.fiftyone_dataset.filter_labels(
+                    f"{self.dataset.gt_field}.{self.dataset.detection_field}", ViewField("label").re_match(self.classes)
+                )
+        else:
+            self.classes = self.fiftyone_dataset.distinct(
+                f"{self.dataset.gt_field}.{self.dataset.detection_field}.label"
+            )
+        # self.hf_datasets = hydra.utils.instantiate(self.dataset)
+        # self.hf_datasets.set_transform(self.transform)
+        #
+        # # Here you should instantiate your dataset, you may also split the train into train and validation if needed.
         if (stage is None or stage == "fit") and (self.train_dataset is None and self.val_dataset is None):
-            self.train_dataset = self.hf_datasets["train"]
-            self.val_dataset = self.hf_datasets["val"]
-
+            self.train_dataset = SegmentationDataset(
+                self.fiftyone_dataset,
+                split=self.split_names["train"],
+                gt_field=self.dataset.gt_field,
+                detection_field=self.dataset.detection_field,
+                transform=self.transform,
+                max_num_instances_per_image=self.dataset.max_num_instances_per_image,
+            )
+            self.val_dataset = SegmentationDataset(
+                self.fiftyone_dataset,
+                split=self.split_names["validation"],
+                gt_field=self.dataset.gt_field,
+                detection_field=self.dataset.detection_field,
+                max_num_instances_per_image=self.dataset.max_num_instances_per_image,
+            )
+        #
         if stage is None or stage == "test":
-            self.test_dataset = self.hf_datasets["test"]
+            self.test_dataset = SegmentationDataset(
+                self.fiftyone_dataset,
+                split=self.split_names["test"],
+                gt_field=self.dataset.gt_field,
+                detection_field=self.dataset.detection_field,
+                max_num_instances_per_image=self.dataset.max_num_instances_per_image,
+            )
 
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
@@ -190,14 +232,14 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}(" f"{self.dataset=}, " f"{self.num_workers=}, " f"{self.batch_size=})"
 
 
-@hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
+@hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default", version_base="1.1")
 def main(cfg: omegaconf.DictConfig) -> None:
     """Debug main to quickly develop the DataModule.
 
     Args:
         cfg: the hydra configuration
     """
-    m: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
+    m: L.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
     m.metadata
     m.setup()