Skip to content

Commit

Permalink
add ability to only return embeddings, without human and mouse head c…
Browse files Browse the repository at this point in the history
…alculation
  • Loading branch information
lucidrains committed Dec 26, 2021
1 parent f5072e1 commit 6f72ed4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
8 changes: 6 additions & 2 deletions enformer_pytorch/enformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def forward(
target = None,
return_corr_coef = False,
return_embeddings = False,
return_only_embeddings = False,
head = None
):
dtype = x.dtype
Expand All @@ -386,12 +387,15 @@ def forward(
x = rearrange(x, '... -> () ...')

x = self._trunk(x)
out = map_values(lambda fn: fn(x), self._heads)

if no_batch:
out = map_values(lambda t: rearrange(t, '() ... -> ...'), out)
x = rearrange(x, '() ... -> ...')

if return_only_embeddings:
return x

out = map_values(lambda fn: fn(x), self._heads)

if exists(head):
assert head in self._heads, f'head {head} not found'
out = out[head]
Expand Down
4 changes: 2 additions & 2 deletions enformer_pytorch/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def forward(
enformer_context = freeze_batchnorm_context(self.enformer) if not freeze_enformer else torch.no_grad()

with enformer_context:
_, embeddings = self.enformer(seq, return_embeddings = True)
embeddings = self.enformer(seq, return_only_embeddings = True)

if freeze_enformer:
embeddings.detach_()
Expand Down Expand Up @@ -90,7 +90,7 @@ def forward(
enformer_context = freeze_batchnorm_context(self.enformer) if not freeze_enformer else torch.no_grad()

with enformer_context:
_, embeddings = self.enformer(seq, return_embeddings = True)
embeddings = self.enformer(seq, return_only_embeddings = True)

if freeze_enformer:
embeddings.detach_()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'enformer-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.5',
version = '0.1.6',
license='MIT',
description = 'Enformer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 6f72ed4

Please sign in to comment.