Skip to content

Commit

Permalink
Now transforming [MASK] -> <mask>, and not the other way round
Browse files Browse the repository at this point in the history
  • Loading branch information
Valentin Zulkower committed Jan 13, 2025
1 parent 81bc0ee commit 9b782a9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
7 changes: 3 additions & 4 deletions ginkgo_ai_client/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,10 +494,9 @@ class RNADiffusionMaskedQuery(QueryBase):
def to_request_params(self) -> Dict:

data = {
"three_utr": self.three_utr.replace(
"<mask>", "[MASK]"
), # UTR tokenizers require [MASK] but api client accepts <mask> for consistence across models
"five_utr": self.five_utr.replace("<mask>", "[MASK]"),
# Many people in the field use [MASK] but our API client uses <mask> for all models
"three_utr": self.three_utr.replace("[MASK]", "<mask>"),
"five_utr": self.five_utr.replace("[MASK]", "<mask>"),
"sequence_aa": self.protein_sequence,
"species": self.species,
"temperature": self.temperature,
Expand Down
20 changes: 11 additions & 9 deletions test/test_mrna_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
GinkgoAIClient,
)


def test_get_mrna_species():
    """Check that get_species_dataframe() returns the expected 324 entries."""
    species_table = RNADiffusionMaskedQuery.get_species_dataframe()
    assert len(species_table) == 324


def test_mrna_diffusion():
client = GinkgoAIClient()
three_utr="AAA<mask>TTTGGGCC<mask><mask>"
five_utr="AAA<mask>TTTGGGCC<mask><mask>"
protein_sequence="MAKS-" # '-' for end of sequence
species="HOMO_SAPIENS"
three_utr = "AAA<mask>TTTGGGCC<mask><mask>"
five_utr = "AAA<mask>TTTGGGCC<mask><mask>"
protein_sequence = "MAKS-" # '-' for end of sequence
species = "HOMO_SAPIENS"
num_samples = 3
query = RNADiffusionMaskedQuery(
three_utr=three_utr,
Expand All @@ -25,7 +27,7 @@ def test_mrna_diffusion():
temperature=1.0,
decoding_order_strategy="entropy",
unmaskings_per_step=10,
num_samples=num_samples
num_samples=num_samples,
)

response = client.send_request(query)
Expand All @@ -39,10 +41,10 @@ def test_mrna_diffusion():

# check codon sequence length: 3 nucleotides per residue; the trailing '-' in protein_sequence already accounts for the stop codon
assert len(sample["codon_sequence"]) == len(protein_sequence) * 3
assert sample["codon_sequence"].startswith("ATG") # Start codon
assert sample["codon_sequence"][-3:] in ["TAA","TAG","TGA"] # stop codon
assert sample["codon_sequence"].startswith("ATG") # Start codon
assert sample["codon_sequence"][-3:] in ["TAA", "TAG", "TGA"] # stop codon

# should translate
translated = str(Seq(sample["codon_sequence"]).translate())
print(translated, protein_sequence)
assert translated.replace("*","-") == protein_sequence
print(translated, protein_sequence)
assert translated.replace("*", "-") == protein_sequence

0 comments on commit 9b782a9

Please sign in to comment.