Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Joeres authored and Roman Joeres committed Aug 15, 2024
1 parent 7aced6d commit 4e9327f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 11 deletions.
4 changes: 2 additions & 2 deletions configs/downstream/dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ model:
- name: gifflar
hidden_dim: 1024
batch_size: 33
num_layers: 10
num_layers: 5
epochs: 100
learning_rate: 0.001
optimizer: Adam
suffix: _1024_12
suffix: _1024_5

3 changes: 2 additions & 1 deletion gifflar/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,8 @@ def process_(self, data, path_idx: int = 0) -> None:
if self.pre_filter is not None:
data = [d for d in data if self.pre_filter(d)]
if self.pre_transform is not None:
data = [self.pre_transform(d) for d in data]
# data = [self.pre_transform(d) for d in data]
data = self.pre_transform(data)

torch.save((data, self.dataset_args), self.processed_paths[path_idx])

Expand Down
23 changes: 16 additions & 7 deletions gifflar/pretransforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,22 @@ def __call__(self, data):

class TQDMCompose(Compose):
def forward(self, data: Union[Data, HeteroData]):
with tqdm(total=len(self.transforms), desc="Transforms") as t_bar:
for transform in self.transforms:
if isinstance(data, (list, tuple)):
data = transform(data)
else:
data = [transform(d) for d in tqdm(data, total=len(data), desc="Samples", leave=False)]
t_bar.update(1)
# print(self.transforms)
# print(len(self.transforms))
# print(len(data))
# with tqdm(total=len(self.transforms), desc="Transforms") as t_bar:
for transform in tqdm(self.transforms, desc=f"Transform"):
if not isinstance(data, (list, tuple)):
data = transform(data)
else:
# data = [transform(d) for d in data]
t_data = []
for d in tqdm(data, leave=False):
t_data.append(transform(d))
data = t_data
# s_bar.update(1)
# data = [transform(d) for d in tqdm(data, total=len(data), desc="Samples", leave=False)]
# t_bar.update(1)
return data


Expand Down
2 changes: 1 addition & 1 deletion gifflar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_metrics(
Accuracy(**metric_args),
AUROC(**metric_args),
MatthewsCorrCoef(**metric_args),
Sensitivity(**metric_args),
# Sensitivity(**metric_args),
])
return {"train": m.clone(prefix="train/"), "val": m.clone(prefix="val/"), "test": m.clone(prefix="test/")}

Expand Down

0 comments on commit 4e9327f

Please sign in to comment.