Skip to content

Commit

Permalink
Add sequence position (seq_pos) to the FoldComp dataset; minor tweaks to the FoldComp-to-PyG conversion (fc_to_pyg)
Browse files — browse the repository at this point in the history
  • Loading branch information
Arian Jamasb committed Feb 6, 2024
1 parent 6e0be18 commit e1ab15f
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions graphein/ml/datasets/foldcomp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ def __init__(
self._get_indices()

super().__init__(
root=self.root, transform=self.transform, pre_transform=None # type: ignore
root=self.root,
transform=self.transform,
pre_transform=None, # type: ignore
)

@property
Expand All @@ -202,9 +204,7 @@ def processed_file_names(self):
def download(self):
"""Downloads foldcomp database if not already downloaded."""

if not all(
os.path.exists(self.root / f) for f in self._database_files
):
if not all(os.path.exists(self.root / f) for f in self._database_files):
log.info(f"Downloading FoldComp dataset {self.database}...")
curr_dir = os.getcwd()
os.chdir(self.root)
Expand Down Expand Up @@ -249,9 +249,7 @@ def _get_indices(self):
]
# Sub sample
log.info(f"Sampling fraction: {self.fraction}...")
accessions = random.sample(
accessions, int(len(accessions) * self.fraction)
)
accessions = random.sample(accessions, int(len(accessions) * self.fraction))
self.ids = accessions
log.info("Creating index...")
indices = dict(enumerate(accessions))
Expand All @@ -264,7 +262,9 @@ def process(self):
# Open the database
log.info("Opening database...")
if self.ids is not None:
self.db = foldcomp.open(self.root / self.database, ids=self.ids, decompress=False) # type: ignore
self.db = foldcomp.open(
self.root / self.database, ids=self.ids, decompress=False
) # type: ignore
else:
self.db = foldcomp.open(self.root / self.database, decompress=False) # type: ignore

Expand All @@ -275,9 +275,10 @@ def fc_to_pyg(data: Dict[str, Any], name: Optional[str] = None) -> Protein:
residue_type = torch.tensor(
[STANDARD_AMINO_ACIDS.index(res) for res in data["residues"]],
)
n_res = len(res)

# Get residue numbers
res_num = [i for i, _ in enumerate(res)]
res_num = np.arange(n_res)

# Get list of atom types
atom_types = []
Expand All @@ -292,7 +293,7 @@ def fc_to_pyg(data: Dict[str, Any], name: Optional[str] = None) -> Protein:
atom_idx = np.array([ATOM_NUMBERING[atm] for atm in atom_types])

# Initialize coordinates
coords = np.ones((len(res), 37, 3)) * 1e-5
coords = np.ones((n_res, 37, 3)) * 1e-5

res_idx = np.repeat(res_num, atom_counts)
coords[res_idx, atom_idx, :] = np.array(data["coordinates"])
Expand All @@ -302,11 +303,12 @@ def fc_to_pyg(data: Dict[str, Any], name: Optional[str] = None) -> Protein:
coords=torch.from_numpy(coords).float(),
residues=res,
residue_id=[f"A:{m}:{str(n)}" for m, n in zip(res, res_num)],
chains=torch.zeros(len(res)),
chains=torch.zeros(n_res),
residue_type=residue_type.long(),
b_factor=torch.from_numpy(b_factor).float(),
id=name,
x=torch.zeros(len(res)),
x=torch.zeros(n_res),
seq_pos=torch.from_numpy(res_num).unsqueeze(-1),
)

def len(self) -> int:
Expand Down Expand Up @@ -380,9 +382,7 @@ def __init__(
self.val_split = val_split
self.test_split = test_split
self.transform = (
self._compose_transforms(transform)
if transform is not None
else None
self._compose_transforms(transform) if transform is not None else None
)

if (
Expand Down Expand Up @@ -421,9 +421,7 @@ def _get_indices(self):
self.ids = ds.ids
ds.db.close()

def _split_data(
self, train_split: float, val_split: float, test_split: float
):
def _split_data(self, train_split: float, val_split: float, test_split: float):
"""Split the database into non-overlapping train, validation and test"""
if not hasattr(self, "ids"):
self._get_indices()
Expand Down

0 comments on commit e1ab15f

Please sign in to comment.