diff --git a/configs/downstream/dev.yaml b/configs/downstream/dev.yaml index 99ab406..360fb82 100644 --- a/configs/downstream/dev.yaml +++ b/configs/downstream/dev.yaml @@ -28,9 +28,9 @@ model: - name: gifflar hidden_dim: 1024 batch_size: 33 - num_layers: 10 + num_layers: 5 epochs: 100 learning_rate: 0.001 optimizer: Adam - suffix: _1024_12 + suffix: _1024_5 diff --git a/gifflar/data.py b/gifflar/data.py index ca4e707..aa4049d 100644 --- a/gifflar/data.py +++ b/gifflar/data.py @@ -433,7 +433,8 @@ def process_(self, data, path_idx: int = 0) -> None: if self.pre_filter is not None: data = [d for d in data if self.pre_filter(d)] if self.pre_transform is not None: - data = [self.pre_transform(d) for d in data] + # data = [self.pre_transform(d) for d in data] + data = self.pre_transform(data) torch.save((data, self.dataset_args), self.processed_paths[path_idx]) diff --git a/gifflar/pretransforms.py b/gifflar/pretransforms.py index b63363b..a6a2bda 100644 --- a/gifflar/pretransforms.py +++ b/gifflar/pretransforms.py @@ -251,13 +251,22 @@ def __call__(self, data): class TQDMCompose(Compose): def forward(self, data: Union[Data, HeteroData]): - with tqdm(total=len(self.transforms), desc="Transforms") as t_bar: - for transform in self.transforms: - if isinstance(data, (list, tuple)): - data = transform(data) - else: - data = [transform(d) for d in tqdm(data, total=len(data), desc="Samples", leave=False)] - t_bar.update(1) + # print(self.transforms) + # print(len(self.transforms)) + # print(len(data)) + # with tqdm(total=len(self.transforms), desc="Transforms") as t_bar: + for transform in tqdm(self.transforms, desc=f"Transform"): + if not isinstance(data, (list, tuple)): + data = transform(data) + else: + # data = [transform(d) for d in data] + t_data = [] + for d in tqdm(data, leave=False): + t_data.append(transform(d)) + data = t_data + # s_bar.update(1) + # data = [transform(d) for d in tqdm(data, total=len(data), desc="Samples", leave=False)] + # t_bar.update(1) return data diff --git a/gifflar/utils.py b/gifflar/utils.py index 3db9f59..73987cb 100644 --- a/gifflar/utils.py +++ b/gifflar/utils.py @@ -116,7 +116,7 @@ def get_metrics( Accuracy(**metric_args), AUROC(**metric_args), MatthewsCorrCoef(**metric_args), - Sensitivity(**metric_args), + # Sensitivity(**metric_args), ]) return {"train": m.clone(prefix="train/"), "val": m.clone(prefix="val/"), "test": m.clone(prefix="test/")}