diff --git a/CHANGES.md b/CHANGES.md index c31a100..ffe3873 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -9,6 +9,10 @@ and `Semantic versioning 2.0.0 `\_, with the exceptions that - versions above `1.0.0` will be numbered as `major.minor.patch`, as is typical +# Unreleased + +Add new query type `RNADiffusionMaskedQuery` + # 0.6.0 (2024-12-17) - More constraints on Boltz diff --git a/ginkgo_ai_client/queries.py b/ginkgo_ai_client/queries.py index e89960a..93ffc80 100644 --- a/ginkgo_ai_client/queries.py +++ b/ginkgo_ai_client/queries.py @@ -464,12 +464,45 @@ class MultimodalDiffusionMaskedResponse(ResponseBase): class RNADiffusionMaskedQuery(QueryBase): """A query to perform masked sampling using a mRNA diffusion model. + Parameters + ---------- + three_utr: str + The three UTR sequence, of the form "ATTGTAC..." + five_utr: str + The five UTR sequence, of the form "ATTGTAC..." + protein_sequence: str + The protein sequence, of the form "MLKKRRK...LP-" (the last character denotes a + stop codon). + species: str + The species, e.g. "HOMO_SAPIENS" + temperature: float, optional (default=1.0) + Sampling temperature, a value between 0 and 1. + decoding_order_strategy: str, optional (default="entropy") + Strategy for decoding order, must be either "max_prob" or "entropy". + unmaskings_per_step: int, optional (default=4) + Number of tokens to unmask per step + num_samples: int, optional (default=1) + Number of samples to generate + model: str + The model to use for the inference, "mrna-foundation" being the only choice + currently. + query_name: Optional[str] = None + The name of the query. It will appear in the API response and can be used to + handle exceptions. + + Returns + ------- + MultimodalDiffusionMaskedResponse + ``client.send_request(query)`` returns a ``MultimodalDiffusionMaskedResponse`` with + attributes ``samples`` (a list of predicted samples, with modality name: predicted sequence) + and ``query_name`` (the original query's name). + Examples -------- >>> query = RNADiffusionMaskedQuery( ... three_utr="ATTGTAC", ... five_utr="ATTGTAC", - ... protein_sequence="ATTGTAC", + ... protein_sequence="MLKKRRK", ... species="HOMO_SAPIENS", ... model="mrna-foundation", ... temperature=1.0, @@ -494,21 +527,19 @@ class RNADiffusionMaskedQuery(QueryBase): def to_request_params(self) -> Dict: data = { - "three_utr": self.three_utr.replace( - "", "[MASK]" - ), # UTR tokenizers require [MASK] but api client accepts for consistence across models - "five_utr": self.five_utr.replace("", "[MASK]"), + "three_utr": self.three_utr, + "five_utr": self.five_utr, "sequence_aa": self.protein_sequence, "species": self.species, "temperature": self.temperature, "decoding_order_strategy": self.decoding_order_strategy, - "num_to_decode_per_step": self.unmaskings_per_step, + "unmaskings_per_step": self.unmaskings_per_step, "num_samples": self.num_samples, } return { "model": self.model, "text": json.dumps(data), - "transforms": [{"type": "GENERATE"}], + "transforms": [{"type": "MRNA_DIFFUSION_GENERATE"}], } def parse_response(self, results: Dict) -> MultimodalDiffusionMaskedResponse: diff --git a/test/test_mrna_model.py b/test/test_mrna_model.py index a2afcfd..5473e01 100644 --- a/test/test_mrna_model.py +++ b/test/test_mrna_model.py @@ -4,17 +4,19 @@ GinkgoAIClient, ) + def test_get_mrna_species(): species = RNADiffusionMaskedQuery.get_species_dataframe() assert len(species) == 324 + def test_mrna_diffusion(): client = GinkgoAIClient() - three_utr="AAATTTGGGCC" - five_utr="AAATTTGGGCC" - protein_sequence="MAKS-" # '-' for end of sequence - species="HOMO_SAPIENS" + three_utr = "AAATTTGGGCC" + five_utr = "AAATTTGGGCC" + protein_sequence = "MAKS-" # '-' for end of sequence + species = "HOMO_SAPIENS" num_samples = 3 query = RNADiffusionMaskedQuery( three_utr=three_utr, @@ -25,7 +27,7 @@ def test_mrna_diffusion(): temperature=1.0, decoding_order_strategy="entropy", unmaskings_per_step=10, - num_samples=num_samples + num_samples=num_samples, ) response = client.send_request(query) @@ -39,10 +41,10 @@ def test_mrna_diffusion(): # check codon sequence verbatim. +1 because of stop codon assert len(sample["codon_sequence"]) == len(protein_sequence) * 3 - assert sample["codon_sequence"].startswith("ATG") # Start codon - assert sample["codon_sequence"][-3:] in ["TAA","TAG","TGA"] # stop codon + assert sample["codon_sequence"].startswith("ATG") # Start codon + assert sample["codon_sequence"][-3:] in ["TAA", "TAG", "TGA"] # stop codon # should translate translated = str(Seq(sample["codon_sequence"]).translate()) - print(translated, protein_sequence) - assert translated.replace("*","-") == protein_sequence \ No newline at end of file + print(translated, protein_sequence) + assert translated.replace("*", "-") == protein_sequence