begin work on fine-tuning from base enformer
lucidrains committed Dec 26, 2021
1 parent 2a706e6 commit 41f107c
Showing 3 changed files with 93 additions and 2 deletions.
40 changes: 39 additions & 1 deletion README.md
@@ -82,7 +82,6 @@ For training, you can directly pass the head and target in to get the poisson loss

```python
import torch
- import torch.nn.functional as F
from enformer_pytorch import Enformer

model = Enformer(
@@ -134,6 +133,42 @@ model = load_pretrained_model('preview')
# do your fine-tuning
```

## Fine-tuning (wip)

This repository will also allow for easy fine-tuning of Enformer. For starters, the following example shows a single step of fine-tuning on contextual data (cell type, transcription factor, etc.).

```python
import torch
from enformer_pytorch import Enformer
from enformer_pytorch.finetune import ContextAdapterWrapper

enformer = Enformer(
    dim = 1536,
    depth = 1,
    heads = 8,
    target_length = 200,
)

model = ContextAdapterWrapper(
    enformer = enformer,
    enformer_dim = 1536,
    context_dim = 1024
).cuda()

seq = torch.randint(0, 4, (1, 196_608 // 2,)).cuda()

target = torch.randn(1, 200, 4).cuda() # 4 tracks
context = torch.randn(4, 1024).cuda() # 4 contexts for the different 'tracks'

loss = model(
    seq,
    context = context,
    target = target
)

loss.backward()
```
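The wrapper's `forward` also exposes a `freeze_enformer` flag (defined in `enformer_pytorch/finetune.py`, shown further down). As a small variation on the step above, passing it keeps the pretrained trunk out of the gradient computation, so only the context adapter parameters are updated:

```python
# same step as above, but with the enformer trunk frozen -
# only the context weights / biases of the adapter receive gradients
loss = model(
    seq,
    context = context,
    target = target,
    freeze_enformer = True
)

loss.backward()
```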

## Appreciation

Special thanks goes out to <a href="https://www.eleuther.ai/">EleutherAI</a> for providing the resources to retrain the model in an acceptable amount of time
@@ -144,6 +179,9 @@ Special thanks goes out to <a href="https://www.eleuther.ai/">EleutherAI</a> for
- [x] add loss wrapper with poisson loss
- [x] move the metrics code over to pytorch as well
- [x] train enformer model
- [ ] allow for plain fine-tuning with a fixed static context
- [ ] build a context manager for fine-tuning with the enformer unfrozen but with frozen batchnorms
- [ ] allow for fine-tuning with only the layernorms unfrozen (a technique from transformer fine-tuning) - a sketch of these last two items follows below
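The last two items are not implemented in this commit. Purely as an illustration, a minimal PyTorch sketch of what they could look like is below; the helpers `freeze_batchnorms_` and `unfreeze_only_layernorms_` are hypothetical names, not part of this repository's API.

```python
from torch import nn

def freeze_batchnorms_(model: nn.Module):
    # hypothetical helper: keep the model trainable overall, but pin the
    # batchnorm running statistics and affine parameters
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            module.eval()                      # stop running-stat updates
            for param in module.parameters():
                param.requires_grad = False    # freeze gamma / beta

def unfreeze_only_layernorms_(model: nn.Module):
    # hypothetical helper: freeze every parameter, then re-enable just the
    # layernorm weights and biases
    for param in model.parameters():
        param.requires_grad = False
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                param.requires_grad = True
```

Note that calling `.train()` on the model flips batchnorm layers back into training mode, so a helper like `freeze_batchnorms_` would need to be re-applied after every `model.train()` call - which is presumably what the planned context manager would handle.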

## Citations

53 changes: 53 additions & 0 deletions enformer_pytorch/finetune.py
@@ -0,0 +1,53 @@
import torch
from contextlib import contextmanager
from torch import nn, einsum
from einops import rearrange
from enformer_pytorch.enformer_pytorch import Enformer, poisson_loss

def exists(val):
    return val is not None

# no-op context manager, used when gradients should be allowed to flow through the enformer trunk
@contextmanager
def null_context():
    yield

class ContextAdapterWrapper(nn.Module):
    def __init__(
        self,
        *,
        enformer,
        enformer_dim,
        context_dim
    ):
        super().__init__()
        assert isinstance(enformer, Enformer)
        self.enformer = enformer

        # project each context embedding to a per-track weight vector and bias
        # (enformer embeddings have dimension enformer_dim * 2)
        self.to_context_weights = nn.Parameter(torch.randn(context_dim, enformer_dim * 2))
        self.to_context_bias = nn.Parameter(torch.randn(context_dim))

    def forward(
        self,
        seq,
        *,
        context,
        target = None,
        freeze_enformer = False
    ):
        # skip graph construction for the trunk entirely when it is frozen
        enformer_context = torch.no_grad if freeze_enformer else null_context

        with enformer_context():
            _, embeddings = self.enformer(seq, return_embeddings = True)

        if freeze_enformer:
            # safeguard, in case the embeddings still hold a reference to the graph
            embeddings.detach_()

        # context: (tracks, context_dim) -> per-track weights (tracks, enformer_dim * 2) and biases (tracks,)
        weights = einsum('t d, d e -> t e', context, self.to_context_weights)
        bias = einsum('t d, d -> t', context, self.to_context_bias)

        # embeddings: (batch, bins, enformer_dim * 2) -> predictions (batch, bins, tracks)
        pred = einsum('b n d, t d -> b n t', embeddings, weights) + bias

        if not exists(target):
            return pred

        return poisson_loss(pred, target)
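For reference, since `forward` returns the raw per-track predictions when no `target` is passed, inference with the README example above could look like this sketch:

```python
# reusing `model`, `seq` and `context` from the README example above
with torch.no_grad():
    pred = model(seq, context = context)  # (1, 200, 4): batch x bins x tracks
```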
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'enformer-pytorch',
packages = find_packages(exclude=[]),
- version = '0.1.1',
+ version = '0.1.2',
license='MIT',
description = 'Enformer - Pytorch',
author = 'Phil Wang',
