diff --git a/CHANGES.md b/CHANGES.md
index c31a100..ffe3873 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -9,6 +9,10 @@ and `Semantic versioning 2.0.0 `\_, with the exceptions that
- versions above `1.0.0` will be numbered as `major.minor.patch`, as is
typical
+# Unreleased
+
+Add new query type `RNADiffusionMaskedQuery`
+
# 0.6.0 (2024-12-17)
- More constraints on Boltz
diff --git a/ginkgo_ai_client/queries.py b/ginkgo_ai_client/queries.py
index e89960a..93ffc80 100644
--- a/ginkgo_ai_client/queries.py
+++ b/ginkgo_ai_client/queries.py
@@ -464,12 +464,45 @@ class MultimodalDiffusionMaskedResponse(ResponseBase):
class RNADiffusionMaskedQuery(QueryBase):
"""A query to perform masked sampling using a mRNA diffusion model.
+ Parameters
+ ----------
+ three_utr: str
+ The three UTR sequence, of the form "ATTGTAC..."
+ five_utr: str
+ The five UTR sequence, of the form "ATTGTAC..."
+ protein_sequence: str
+ The protein sequence, of the form "MLKKRRK...LP-" (the last character denotes a
+ stop codon).
+ species: str
+ The species, e.g. "HOMO_SAPIENS"
+ temperature: float, optional (default=1.0)
+ Sampling temperature, a value between 0 and 1.
+ decoding_order_strategy: str, optional (default="entropy")
+ Strategy for decoding order, must be either "max_prob" or "entropy".
+ unmaskings_per_step: int, optional (default=4)
+ Number of tokens to unmask per step
+ num_samples: int, optional (default=1)
+ Number of samples to generate
+ model: str
+ The model to use for the inference, "mrna-foundation" being the only choice
+ currently.
+ query_name: Optional[str] = None
+ The name of the query. It will appear in the API response and can be used to
+ handle exceptions.
+
+ Returns
+ -------
+ MultimodalDiffusionMaskedResponse
+ ``client.send_request(query)`` returns a ``MultimodalDiffusionMaskedResponse`` with
+ attributes ``samples`` (a list of predicted samples, with modality name: predicted sequence)
+ and ``query_name`` (the original query's name).
+
Examples
--------
>>> query = RNADiffusionMaskedQuery(
... three_utr="ATTGTAC",
... five_utr="ATTGTAC",
- ... protein_sequence="ATTGTAC",
+ ... protein_sequence="MLKKRRK",
... species="HOMO_SAPIENS",
... model="mrna-foundation",
... temperature=1.0,
@@ -494,21 +527,19 @@ class RNADiffusionMaskedQuery(QueryBase):
def to_request_params(self) -> Dict:
data = {
- "three_utr": self.three_utr.replace(
- "", "[MASK]"
- ), # UTR tokenizers require [MASK] but api client accepts for consistence across models
- "five_utr": self.five_utr.replace("", "[MASK]"),
+ "three_utr": self.three_utr,
+ "five_utr": self.five_utr,
"sequence_aa": self.protein_sequence,
"species": self.species,
"temperature": self.temperature,
"decoding_order_strategy": self.decoding_order_strategy,
- "num_to_decode_per_step": self.unmaskings_per_step,
+ "unmaskings_per_step": self.unmaskings_per_step,
"num_samples": self.num_samples,
}
return {
"model": self.model,
"text": json.dumps(data),
- "transforms": [{"type": "GENERATE"}],
+ "transforms": [{"type": "MRNA_DIFFUSION_GENERATE"}],
}
def parse_response(self, results: Dict) -> MultimodalDiffusionMaskedResponse:
diff --git a/test/test_mrna_model.py b/test/test_mrna_model.py
index a2afcfd..5473e01 100644
--- a/test/test_mrna_model.py
+++ b/test/test_mrna_model.py
@@ -4,17 +4,19 @@
GinkgoAIClient,
)
+
def test_get_mrna_species():
species = RNADiffusionMaskedQuery.get_species_dataframe()
assert len(species) == 324
+
def test_mrna_diffusion():
client = GinkgoAIClient()
- three_utr="AAATTTGGGCC"
- five_utr="AAATTTGGGCC"
- protein_sequence="MAKS-" # '-' for end of sequence
- species="HOMO_SAPIENS"
+ three_utr = "AAATTTGGGCC"
+ five_utr = "AAATTTGGGCC"
+ protein_sequence = "MAKS-" # '-' for end of sequence
+ species = "HOMO_SAPIENS"
num_samples = 3
query = RNADiffusionMaskedQuery(
three_utr=three_utr,
@@ -25,7 +27,7 @@ def test_mrna_diffusion():
temperature=1.0,
decoding_order_strategy="entropy",
unmaskings_per_step=10,
- num_samples=num_samples
+ num_samples=num_samples,
)
response = client.send_request(query)
@@ -39,10 +41,10 @@ def test_mrna_diffusion():
# check codon sequence verbatim. +1 because of stop codon
assert len(sample["codon_sequence"]) == len(protein_sequence) * 3
- assert sample["codon_sequence"].startswith("ATG") # Start codon
- assert sample["codon_sequence"][-3:] in ["TAA","TAG","TGA"] # stop codon
+ assert sample["codon_sequence"].startswith("ATG") # Start codon
+ assert sample["codon_sequence"][-3:] in ["TAA", "TAG", "TGA"] # stop codon
# should translate
translated = str(Seq(sample["codon_sequence"]).translate())
- print(translated, protein_sequence)
- assert translated.replace("*","-") == protein_sequence
\ No newline at end of file
+ print(translated, protein_sequence)
+ assert translated.replace("*", "-") == protein_sequence