From e1ab15fb00b349aed87d1684313f5ed5ded6c90c Mon Sep 17 00:00:00 2001
From: Arian Jamasb
Date: Tue, 6 Feb 2024 09:56:09 +0100
Subject: [PATCH] add seq position to foldcomp dataset, minor tweaks to fc->pyg conversion

---
 graphein/ml/datasets/foldcomp_dataset.py | 34 +++++++++++-------------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/graphein/ml/datasets/foldcomp_dataset.py b/graphein/ml/datasets/foldcomp_dataset.py
index 8d033518..570f7097 100644
--- a/graphein/ml/datasets/foldcomp_dataset.py
+++ b/graphein/ml/datasets/foldcomp_dataset.py
@@ -180,7 +180,9 @@ def __init__(
         self._get_indices()

         super().__init__(
-            root=self.root, transform=self.transform, pre_transform=None  # type: ignore
+            root=self.root,
+            transform=self.transform,
+            pre_transform=None,  # type: ignore
         )

     @property
@@ -202,9 +204,7 @@ def processed_file_names(self):

     def download(self):
         """Downloads foldcomp database if not already downloaded."""
-        if not all(
-            os.path.exists(self.root / f) for f in self._database_files
-        ):
+        if not all(os.path.exists(self.root / f) for f in self._database_files):
             log.info(f"Downloading FoldComp dataset {self.database}...")
             curr_dir = os.getcwd()
             os.chdir(self.root)
@@ -249,9 +249,7 @@ def _get_indices(self):
         ]
         # Sub sample
         log.info(f"Sampling fraction: {self.fraction}...")
-        accessions = random.sample(
-            accessions, int(len(accessions) * self.fraction)
-        )
+        accessions = random.sample(accessions, int(len(accessions) * self.fraction))
         self.ids = accessions
         log.info("Creating index...")
         indices = dict(enumerate(accessions))
@@ -264,7 +262,9 @@ def process(self):
         # Open the database
         log.info("Opening database...")
         if self.ids is not None:
-            self.db = foldcomp.open(self.root / self.database, ids=self.ids, decompress=False)  # type: ignore
+            self.db = foldcomp.open(
+                self.root / self.database, ids=self.ids, decompress=False
+            )  # type: ignore
         else:
             self.db = foldcomp.open(self.root / self.database, decompress=False)  # type: ignore

@@ -275,9 +275,10 @@ def fc_to_pyg(data: Dict[str, Any], name: Optional[str] = None) -> Protein:
         residue_type = torch.tensor(
             [STANDARD_AMINO_ACIDS.index(res) for res in data["residues"]],
         )
+        n_res = len(res)

         # Get residue numbers
-        res_num = [i for i, _ in enumerate(res)]
+        res_num = np.arange(n_res)

         # Get list of atom types
         atom_types = []
@@ -292,7 +293,7 @@ def fc_to_pyg(data: Dict[str, Any], name: Optional[str] = None) -> Protein:
         atom_idx = np.array([ATOM_NUMBERING[atm] for atm in atom_types])

         # Initialize coordinates
-        coords = np.ones((len(res), 37, 3)) * 1e-5
+        coords = np.ones((n_res, 37, 3)) * 1e-5

         res_idx = np.repeat(res_num, atom_counts)
         coords[res_idx, atom_idx, :] = np.array(data["coordinates"])
@@ -302,11 +303,12 @@ def fc_to_pyg(data: Dict[str, Any], name: Optional[str] = None) -> Protein:
             coords=torch.from_numpy(coords).float(),
             residues=res,
             residue_id=[f"A:{m}:{str(n)}" for m, n in zip(res, res_num)],
-            chains=torch.zeros(len(res)),
+            chains=torch.zeros(n_res),
             residue_type=residue_type.long(),
             b_factor=torch.from_numpy(b_factor).float(),
             id=name,
-            x=torch.zeros(len(res)),
+            x=torch.zeros(n_res),
+            seq_pos=torch.from_numpy(res_num).unsqueeze(-1),
         )

     def len(self) -> int:
@@ -380,9 +382,7 @@ def __init__(
         self.val_split = val_split
         self.test_split = test_split
         self.transform = (
-            self._compose_transforms(transform)
-            if transform is not None
-            else None
+            self._compose_transforms(transform) if transform is not None else None
         )

         if (
@@ -421,9 +421,7 @@ def _get_indices(self):
         self.ids = ds.ids
         ds.db.close()

-    def _split_data(
-        self, train_split: float, val_split: float, test_split: float
-    ):
+    def _split_data(self, train_split: float, val_split: float, test_split: float):
         """Split the database into non-overlapping train, validation and test"""
         if not hasattr(self, "ids"):
             self._get_indices()
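
Note on the substantive change: fc_to_pyg now numbers residues with np.arange and exposes that numbering on the returned Protein object as a new seq_pos tensor of shape (n_res, 1); the other hunks are formatting-only. Below is a minimal, self-contained sketch of that indexing logic. The input values and helper names are illustrative only and are not part of the Graphein API.

import numpy as np
import torch

# Hypothetical example: 5 residues with 4, 5, 8, 7 and 6 atoms respectively.
residues = ["A", "L", "W", "F", "S"]
atom_counts = np.array([4, 5, 8, 7, 6])

n_res = len(residues)
res_num = np.arange(n_res)  # per-residue sequence position 0..n_res-1

# Atom37-style coordinate tensor: one (37, 3) slab per residue,
# initialised to a small sentinel value as in the patch.
coords = np.ones((n_res, 37, 3)) * 1e-5
res_idx = np.repeat(res_num, atom_counts)  # residue index for every atom
# (the per-atom column index into the 37 slots comes from ATOM_NUMBERING
#  in the real function and is omitted here)

# New feature added by this patch: per-residue sequence position,
# stored as an (n_res, 1) tensor on the Protein object.
seq_pos = torch.from_numpy(res_num).unsqueeze(-1)
print(seq_pos.shape)  # torch.Size([5, 1])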