Skip to content

Commit

Permalink
Merge pull request #22 from ginkgobioworks/vz/mrna-foundation-tweaks-2
Browse files Browse the repository at this point in the history
tweaks to mrna_foundation diffusion parameters
  • Loading branch information
Zulko authored Jan 14, 2025
2 parents 212b842 + 1ecc9b8 commit 9131ba0
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and `Semantic versioning 2.0.0 <http://semver.org/>`\_, with the exceptions that
- versions above `1.0.0` will be numbered as `major.minor.patch`, as is
typical

# Unreleased

- Add new query type `RNADiffusionMaskedQuery`

# 0.6.0 (2024-12-17)

- More constraints on Boltz
Expand Down
45 changes: 38 additions & 7 deletions ginkgo_ai_client/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,45 @@ class MultimodalDiffusionMaskedResponse(ResponseBase):
class RNADiffusionMaskedQuery(QueryBase):
"""A query to perform masked sampling using a mRNA diffusion model.
Parameters
----------
three_utr: str
The 3' UTR (three prime untranslated region) sequence, of the form "ATTG<mask>TAC..."
five_utr: str
The 5' UTR (five prime untranslated region) sequence, of the form "ATTG<mask>TAC..."
protein_sequence: str
The protein sequence, of the form "MLKKRRK...LP-" (the last character denotes a
stop codon).
species: str
The species, e.g. "HOMO_SAPIENS"
temperature: float, optional (default=1.0)
Sampling temperature, a value between 0 and 1.
decoding_order_strategy: str, optional (default="entropy")
Strategy for decoding order, must be either "max_prob" or "entropy".
unmaskings_per_step: int, optional (default=4)
Number of tokens to unmask per step
num_samples: int, optional (default=1)
Number of samples to generate
model: str
The model to use for the inference, "mrna-foundation" being the only choice
currently.
query_name: Optional[str] = None
The name of the query. It will appear in the API response and can be used to
handle exceptions.
Returns
-------
MultimodalDiffusionMaskedResponse
``client.send_request(query)`` returns a ``MultimodalDiffusionMaskedResponse`` with
attributes ``samples`` (a list of predicted samples, each mapping a modality name to its predicted sequence)
and ``query_name`` (the original query's name).
Examples
--------
>>> query = RNADiffusionMaskedQuery(
... three_utr="ATTG<mask>TAC",
... five_utr="ATTG<mask>TAC",
... protein_sequence="ATTG<mask>TAC",
... protein_sequence="MLKKRRK",
... species="HOMO_SAPIENS",
... model="mrna-foundation",
... temperature=1.0,
Expand All @@ -494,21 +527,19 @@ class RNADiffusionMaskedQuery(QueryBase):
def to_request_params(self) -> Dict:

data = {
"three_utr": self.three_utr.replace(
"<mask>", "[MASK]"
), # UTR tokenizers require [MASK] but api client accepts <mask> for consistence across models
"five_utr": self.five_utr.replace("<mask>", "[MASK]"),
"three_utr": self.three_utr,
"five_utr": self.five_utr,
"sequence_aa": self.protein_sequence,
"species": self.species,
"temperature": self.temperature,
"decoding_order_strategy": self.decoding_order_strategy,
"num_to_decode_per_step": self.unmaskings_per_step,
"unmaskings_per_step": self.unmaskings_per_step,
"num_samples": self.num_samples,
}
return {
"model": self.model,
"text": json.dumps(data),
"transforms": [{"type": "GENERATE"}],
"transforms": [{"type": "MRNA_DIFFUSION_GENERATE"}],
}

def parse_response(self, results: Dict) -> MultimodalDiffusionMaskedResponse:
Expand Down
20 changes: 11 additions & 9 deletions test/test_mrna_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
GinkgoAIClient,
)


def test_get_mrna_species():

species = RNADiffusionMaskedQuery.get_species_dataframe()
assert len(species) == 324


def test_mrna_diffusion():
client = GinkgoAIClient()
three_utr="AAA<mask>TTTGGGCC<mask><mask>"
five_utr="AAA<mask>TTTGGGCC<mask><mask>"
protein_sequence="MAKS-" # '-' for end of sequence
species="HOMO_SAPIENS"
three_utr = "AAA<mask>TTTGGGCC<mask><mask>"
five_utr = "AAA<mask>TTTGGGCC<mask><mask>"
protein_sequence = "MAKS-" # '-' for end of sequence
species = "HOMO_SAPIENS"
num_samples = 3
query = RNADiffusionMaskedQuery(
three_utr=three_utr,
Expand All @@ -25,7 +27,7 @@ def test_mrna_diffusion():
temperature=1.0,
decoding_order_strategy="entropy",
unmaskings_per_step=10,
num_samples=num_samples
num_samples=num_samples,
)

response = client.send_request(query)
Expand All @@ -39,10 +41,10 @@ def test_mrna_diffusion():

# check codon sequence length: 3 nucleotides per residue; the trailing '-' in protein_sequence accounts for the stop codon
assert len(sample["codon_sequence"]) == len(protein_sequence) * 3
assert sample["codon_sequence"].startswith("ATG") # Start codon
assert sample["codon_sequence"][-3:] in ["TAA","TAG","TGA"] # stop codon
assert sample["codon_sequence"].startswith("ATG") # Start codon
assert sample["codon_sequence"][-3:] in ["TAA", "TAG", "TGA"] # stop codon

# should translate
translated = str(Seq(sample["codon_sequence"]).translate())
print(translated, protein_sequence)
assert translated.replace("*","-") == protein_sequence
print(translated, protein_sequence)
assert translated.replace("*", "-") == protein_sequence

0 comments on commit 9131ba0

Please sign in to comment.