diff --git a/enformer_pytorch/enformer_pytorch.py b/enformer_pytorch/enformer_pytorch.py
index 7e96e83..895ef6e 100644
--- a/enformer_pytorch/enformer_pytorch.py
+++ b/enformer_pytorch/enformer_pytorch.py
@@ -373,6 +373,7 @@ def forward(
         target = None,
         return_corr_coef = False,
         return_embeddings = False,
+        return_only_embeddings = False,
         head = None
     ):
         dtype = x.dtype
@@ -386,12 +387,15 @@ def forward(
             x = rearrange(x, '... -> () ...')

         x = self._trunk(x)
-        out = map_values(lambda fn: fn(x), self._heads)

         if no_batch:
-            out = map_values(lambda t: rearrange(t, '() ... -> ...'), out)
             x = rearrange(x, '() ... -> ...')

+        if return_only_embeddings:
+            return x
+
+        out = map_values(lambda fn: fn(x), self._heads)
+
         if exists(head):
             assert head in self._heads, f'head {head} not found'
             out = out[head]
diff --git a/enformer_pytorch/finetune.py b/enformer_pytorch/finetune.py
index e15c110..05256a6 100644
--- a/enformer_pytorch/finetune.py
+++ b/enformer_pytorch/finetune.py
@@ -53,7 +53,7 @@ def forward(
         enformer_context = freeze_batchnorm_context(self.enformer) if not freeze_enformer else torch.no_grad()

         with enformer_context:
-            _, embeddings = self.enformer(seq, return_embeddings = True)
+            embeddings = self.enformer(seq, return_only_embeddings = True)

         if freeze_enformer:
             embeddings.detach_()
@@ -90,7 +90,7 @@ def forward(
         enformer_context = freeze_batchnorm_context(self.enformer) if not freeze_enformer else torch.no_grad()

         with enformer_context:
-            _, embeddings = self.enformer(seq, return_embeddings = True)
+            embeddings = self.enformer(seq, return_only_embeddings = True)

         if freeze_enformer:
             embeddings.detach_()
diff --git a/setup.py b/setup.py
index 6c0a672..0cecb18 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'enformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.5',
+  version = '0.1.6',
   license='MIT',
   description = 'Enformer - Pytorch',
   author = 'Phil Wang',
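
Usage sketch of the new flag (illustrative only; the constructor arguments below follow the repository README and depend on your own config). Previously, return_embeddings = True computed the output heads and returned a (predictions, embeddings) tuple, so callers such as finetune.py had to discard the predictions; return_only_embeddings = True skips the heads entirely and returns just the trunk embeddings.

import torch
from enformer_pytorch import Enformer

# illustrative hyperparameters taken from the repository README; adjust to your setup
enformer = Enformer(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896
)

seq = torch.randint(0, 5, (1, 196_608))   # batch of one integer-encoded sequence

# old path: heads are computed, embeddings returned alongside the predictions
# out, embeddings = enformer(seq, return_embeddings = True)

# new path added in this diff: skip the output heads, return trunk embeddings only
embeddings = enformer(seq, return_only_embeddings = True)   # (1, target_length, embedding dim)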