Skip to content

Commit

Permalink
Merge pull request #11 from ginkgobioworks/add_lcdna_endpoint
Browse files Browse the repository at this point in the history
Add lcdna and abdiffusion endpoints
  • Loading branch information
Zulko authored Dec 13, 2024
2 parents 2c97718 + a7536c7 commit 81b19ee
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 4 deletions.
3 changes: 3 additions & 0 deletions ginkgo_ai_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
MaskedInferenceQuery,
MeanEmbeddingQuery,
PromoterActivityQuery,
DiffusionMaskedQuery,
DiffusionMaskedResponse,
)

__all__ = [
"GinkgoAIClient",
"MaskedInferenceQuery",
"MeanEmbeddingQuery",
"PromoterActivityQuery",
"DiffusionMaskedQuery",
]
102 changes: 102 additions & 0 deletions ginkgo_ai_client/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def write_to_jsonl(self, path: str):
"esm2-650M": "protein",
"esm2-3B": "protein",
"ginkgo-maskedlm-3utr-v1": "dna",
"lcdna": "nucleotide",
"abdiffusion": "protein",
}

_maskedlm_models_properties_str = "\n".join(
Expand All @@ -68,6 +70,11 @@ def _validate_model_and_sequence(model, sequence: str, allow_masks=False):
raise ValueError(
f"Model {model} requires the sequence to only contain ATGC characters"
)
elif sequence_type == "nucleotide":
if not set(sequence.lower()).issubset({"a", "t", "g", "c", "r", "y", "s", "w", "k", "m", "b", "d", "h", "v", "n"}):
raise ValueError(
f"Model {model} requires the sequence to only contain valid IUPAC nucleotide characters"
)
elif sequence_type == "protein":
if not set(sequence).issubset(set("ACDEFGHIKLMNPQRSTVWY")):
raise ValueError("Sequence must contain only protein characters")
Expand Down Expand Up @@ -381,3 +388,98 @@ def list_with_promoter_from_fasta(
model=model,
)
return list(iterator)


class DiffusionMaskedResponse(ResponseBase):
"""A response to a DiffusionMaskedQuery, with attributes `sequence` (the predicted
sequence) and `query_name` (the original query's name).
"""

sequence: str
query_name: Optional[str] = None


class DiffusionMaskedQuery(QueryBase):
"""A query to perform masked sampling using a diffusion model.
Parameters
----------
sequence: str
Input sequence for masked sampling. The sequence may contain "<mask>" tokens.
temperature: float, optional (default=0.5)
Sampling temperature, a value between 0 and 1.
decoding_order_strategy: str, optional (default="entropy")
Strategy for decoding order, must be either "max_prob" or "entropy".
unmaskings_per_step: int, optional (default=50)
Number of tokens to unmask per step, an integer between 1 and 1000.
model: str
The model to use for the inference.
query_name: Optional[str] = None
The name of the query. It will appear in the API response and can be used to handle exceptions.
Returns
-------
DiffusionMaskedResponse
``client.send_request(query)`` returns a ``DiffusionMaskedResponse`` with attributes
``sequence`` (the predicted sequence) and ``query_name`` (the original query's name).
Examples
--------
>>> query = DiffusionMaskedQuery(
... sequence="ATTG<mask>TAC",
... model="lcdna",
... temperature=0.7,
... decoding_order_strategy="entropy",
... unmaskings_per_step=20,
... )
>>> client.send_request(query)
DiffusionMaskedResponse(sequence="ATTGCGTAC", query_name=None)
"""

sequence: str
temperature: float = 0.5
decoding_order_strategy: str = "entropy"
unmaskings_per_step: int = 50
model: str
query_name: Optional[str] = None

def to_request_params(self) -> Dict:
data = {
"sequence": self.sequence,
"temperature": self.temperature,
"decoding_order_strategy": self.decoding_order_strategy,
"unmaskings_per_step": self.unmaskings_per_step,
}
return {
"model": self.model,
"text": json.dumps(data),
"transforms": [{"type": "DIFFUSION_GENERATE"}],
}

def parse_response(self, results: Dict) -> DiffusionMaskedResponse:
return DiffusionMaskedResponse(
sequence=results["sequence"][0],
query_name=self.query_name,
)

@pydantic.model_validator(mode="after")
def validate_query(cls, query):
sequence, model = query.sequence, query.model
# Validate sequence and model compatibility
_validate_model_and_sequence(
model=model,
sequence=sequence,
allow_masks=True,
)
# Validate temperature
if not 0 <= query.temperature <= 1:
raise ValueError("temperature must be between 0 and 1")
# Validate decoding_order_strategy
if query.decoding_order_strategy not in ["max_prob", "entropy"]:
raise ValueError(
"decoding_order_strategy must be 'max_prob' or 'entropy'"
)
# Validate unmaskings_per_step
if not 1 <= query.unmaskings_per_step <= 1000:
raise ValueError("unmaskings_per_step must be between 1 and 1000")
return query
10 changes: 6 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ biopython==1.82.0
pytest-xdist==3.6.1
pytest-cov==4.0.0

sphinx==8.1.3,
docutils==0.21.2,
myst-parser==4.0.0,
shibuya==2024.10.15
sphinx==8.1.3
docutils==0.21.2
myst-parser==4.0.0
shibuya==2024.10.15
tqdm==4.67.1
pydantic==2.10.3
22 changes: 22 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
MaskedInferenceQuery,
MeanEmbeddingQuery,
PromoterActivityQuery,
DiffusionMaskedQuery,
)


Expand All @@ -14,6 +15,7 @@
("ginkgo-aa0-650M", "MCL<mask>YAFVATDA<mask>DDT", "MCLLYAFVATDADDDT"),
("esm2-650M", "MCL<mask>YAFVATDA<mask>DDT", "MCLLYAFVATDAADDT"),
("ginkgo-maskedlm-3utr-v1", "ATTG<mask>G", "ATTGGG"),
("lcdna", "ATRGAyAtg<mask>TAC<mask>", "atggatatgtta<unk>"),
],
)
def test_masked_inference(model, sequence, expected_sequence):
Expand Down Expand Up @@ -61,3 +63,23 @@ def test_promoter_activity():
response = client.send_request(query)
assert "heart" in response.activity_by_tissue
assert "liver" in response.activity_by_tissue

@pytest.mark.parametrize(
"model, sequence",
[
("lcdna", "ATRGAyAtg<mask>TAC<mask>"),
("abdiffusion", "MCL<mask>YAFVATDA<mask>DDT"),
],
)
def test_diffusion_masked_inference(model, sequence):
client = GinkgoAIClient()
query = DiffusionMaskedQuery(
sequence=sequence, #upper and lower cases
model=model,
temperature=0.5,
decoding_order_strategy="entropy",
unmaskings_per_step=2,
)
response = client.send_request(query)
assert isinstance(response.sequence, str)
assert "<mask>" not in response.sequence

0 comments on commit 81b19ee

Please sign in to comment.