feat: refactor prediction writer (#31)
kartikey-vyas authored Aug 26, 2024
1 parent 041fce2 commit 9443a74
Showing 13 changed files with 349 additions and 193 deletions.
56 changes: 47 additions & 9 deletions .github/workflows/predict-parallel.yml
@@ -8,6 +8,11 @@ on:
sample-only:
required: false
type: string
n-workers:
description: 'number of workers to use (max 50)'
required: true
default: '50'
type: string

permissions:
contents: read
@@ -21,25 +26,58 @@ jobs:
- name: Get start time
id: start-time
run: echo "start-time=$(date +%s)" >> "$GITHUB_OUTPUT"

# pull-model-image:
# runs-on: ubuntu-latest
# steps:
# - name: Remove unnecessary files
# run: |
# sudo rm -rf /usr/share/dotnet
# sudo rm -rf "$AGENT_TOOLSDIRECTORY"
# sudo rm -rf /opt/ghc
# sudo rm -rf "/usr/local/share/boost"
# - name: cache model Docker image
# id: cache
# uses: actions/cache@v3
# with:
# path: /tmp/${{ inputs.model-id }}.tar
# key: ${{ runner.os }}-docker-${{ inputs.model-id }}
# restore-keys: |
# ${{ runner.os }}-docker-${{ inputs.model-id }}

# - name: pull and save requested model image
# if: steps.cache.outputs.cache-hit != 'true'
# run: |
# docker pull ersiliaos/${{ inputs.model-id }}:latest
# docker save ersiliaos/${{ inputs.model-id }} -o /tmp/${{ inputs.model-id }}.tar


generate-matrix:
# needs: pull-model-image
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- id: set-matrix
run: |
start='1'
end=${{ inputs.n-workers }}
matrix=$(seq -s ',' $start $end)
echo "matrix=[${matrix}]" >> $GITHUB_OUTPUT
matrix-inference:
needs: generate-matrix
if: github.repository != 'ersilia-os/eos-template'

strategy:
matrix:
# numerator: [
# 1,2,3,4,5,6,7,8,9,10,
# 11,12,13,14,15,16,17,18,19,20,
# 21,22,23,24,25,26,27,28,29,30,
# 31,32,33,34,35,36,37,38,39,40,
# 41,42,43,44,45,46,47,48,49,50
# ]
numerator: [1]
numerator: ${{ fromJson(needs.generate-matrix.outputs.matrix) }}

uses: ./.github/workflows/predict.yml
with:
numerator: ${{ matrix.numerator }}
denominator: 50
denominator: ${{ inputs.n-workers }}
model-id: ${{ inputs.model-id }}
sample-only: ${{ inputs.sample-only }}
SHA: ${{ github.sha }}
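
Note: the fan-out is now parameterised by the new n-workers input. The generate-matrix job builds the string `[1,2,...,N]` with `seq -s ',' 1 N`, `fromJson` expands it into one numerator per matrix job, and the same n-workers value is reused as the denominator. A minimal Python sketch of the string construction, useful for checking the format locally (the helper name is illustrative, not part of this repository):

```python
def build_matrix(n_workers: int) -> str:
    """Reproduce what the set-matrix step writes to $GITHUB_OUTPUT: seq -s ',' 1 N wrapped in brackets."""
    return "[" + ",".join(str(i) for i in range(1, n_workers + 1)) + "]"


assert build_matrix(3) == "[1,2,3]"  # fromJson('[1,2,3]') yields matrix numerators 1, 2, 3
```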
50 changes: 35 additions & 15 deletions .github/workflows/predict.yml
@@ -5,19 +5,19 @@ on:
inputs:
numerator:
required: true
type: number
type: string
denominator:
required: true
type: number
type: string
model-id:
required: true
type: string
sha:
required: true
type: string
sample-only:
required: false
type: string
SHA:
required: true
type: string

jobs:
infer-and-upload:
@@ -44,12 +44,11 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y make

- name: Set up Python environment
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Run make install
run: make install-prod
cache: 'pip'
- run: make install-prod

# we need this step as ersilia will use the default conda environment to run the example model during `ersilia serve`
# could get around this eventually if we only use conda for env management, but there are complexities around referencing a dev
@@ -67,12 +66,33 @@ jobs:
- name: Activate virtual environment
run: source .venv/bin/activate

- name: Remove unnecessary files
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo rm -rf /opt/ghc
sudo rm -rf "/usr/local/share/boost"
# - name: Restore cached Docker image
# id: cache
# uses: actions/cache@v3
# with:
# path: /tmp/${{ inputs.model-id }}.tar
# key: ${{ runner.os }}-docker-${{ inputs.model-id }}
# restore-keys: |
# ${{ runner.os }}-docker-${{ inputs.model-id }}

# - name: Load Docker image
# if: steps.cache.outputs.cache-hit == 'true'
# run: |
# docker load -i /tmp/${{ inputs.model-id }}.tar

- name: Run Python script to generate predictions and upload to S3
env:
MODEL_ID: ${{ inputs.model-id }}
SHA: ${{ inputs.SHA }}
numerator: ${{ inputs.numerator }}
sample-only: ${{ inputs.sample-only }}
GITHUB_REPOSITORY: ${{ github.event.repository.full_name }}
run: .venv/bin/python scripts/generate_predictions_with_aws_and_ersilia.py ci
INPUT_MODEL_ID: ${{ inputs.model-id }}
INPUT_SHA: ${{ inputs.SHA }}
INPUT_NUMERATOR: ${{ inputs.numerator }}
INPUT_DENOMINATOR: ${{ inputs.denominator }}
INPUT_SAMPLE_ONLY: ${{ inputs.sample-only }}
run: .venv/bin/python scripts/generate_predictions.py --env prod
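
The worker parameters now reach the script as INPUT_-prefixed environment variables, and the entry point moves from scripts/generate_predictions_with_aws_and_ersilia.py to scripts/generate_predictions.py with an --env flag. A hedged sketch of how those variables could be folded into the new WorkerConfig (the helper below is illustrative; the actual parsing in scripts/generate_predictions.py may differ, and INPUT_MODEL_ID is assumed to be handled separately):

```python
import os

from config.app import WorkerConfig


def worker_config_from_env() -> WorkerConfig:
    """Illustrative mapping of the INPUT_* variables set in predict.yml onto WorkerConfig."""
    return WorkerConfig(
        git_sha=os.environ["INPUT_SHA"],
        numerator=int(os.environ["INPUT_NUMERATOR"]),
        denominator=int(os.environ["INPUT_DENOMINATOR"]),
        sample=os.environ.get("INPUT_SAMPLE_ONLY") or None,
    )
```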
4 changes: 3 additions & 1 deletion .gitignore
@@ -172,4 +172,6 @@ data/
.DS_Store

# VSCode
.vscode/
.vscode/

*.csv
2 changes: 1 addition & 1 deletion Makefile
@@ -16,7 +16,7 @@ install:
@if [ "$(shell which poetry)" = "" ]; then \
$(MAKE) install-poetry; \
fi
@$(MAKE) setup-poetry install-hooks
@$(MAKE) install-ersilia setup-poetry install-hooks

install-prod:
@if [ "$(shell which poetry)" = "" ]; then \
10 changes: 10 additions & 0 deletions config/app.py
@@ -1,3 +1,6 @@
from typing import Optional

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


@@ -12,3 +15,10 @@ class DataLakeConfig(BaseSettings):
athena_database: str
athena_prediction_table: str
athena_request_table: str


class WorkerConfig(BaseModel):
git_sha: str
denominator: int # the total number of workers to split data over
numerator: int # the number assigned to this worker
sample: Optional[str] = None # sample size of reference library (in number of rows)
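
WorkerConfig carries the fan-out parameters each matrix job receives. One natural reading of the numerator/denominator pair is a strided split of the reference library; the sketch below shows that interpretation (an assumption about intent, not code taken from this repository):

```python
import pandas as pd

from config.app import WorkerConfig


def select_worker_rows(df: pd.DataFrame, cfg: WorkerConfig) -> pd.DataFrame:
    """Strided split: worker k of N takes rows k-1, k-1+N, k-1+2N, ..."""
    if cfg.sample is not None:
        df = df.head(int(cfg.sample))  # optional cap on the reference library size
    return df.iloc[cfg.numerator - 1 :: cfg.denominator]
```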
2 changes: 1 addition & 1 deletion notebooks/athena-poc.ipynb
@@ -263,7 +263,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.13"
}
},
"nbformat": 4,
22 changes: 21 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

25 changes: 5 additions & 20 deletions precalculator/fetcher.py
@@ -7,27 +7,20 @@

from config.app import DataLakeConfig

# CLI ensures that input is correctly formatted:
# | input key |
# on the
# presignurl -> a specific S3 bucket, object name is the request ID, prefix is the model ID
#
# s3://bucket/model-id/request-id.csv


class PredictionFetcher:
def __init__(self, config: DataLakeConfig, user_id: str, request_id: str, model_id: str, dev: bool = False):
self.config = config
self.user_id = user_id
self.request_id = request_id
# TODO: decide on multi model implementation, for now assume a list of 1 model ID
self.model_id = model_id
self.dev = dev

self.logger = logging.getLogger("PredictionFetcher")
self.logger.setLevel(logging.INFO)

if self.dev:
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
logging.getLogger("botocore").setLevel(logging.WARNING)

def check_availability(self) -> str:
@@ -50,18 +43,10 @@ def fetch(self, path_to_input: str) -> pd.DataFrame:
input_df = self._read_input_data(path_to_input)

logger.info("writing input to athena")
try:
self._write_inputs_s3(input_df)
except Exception as e:
print(f"error {e}")
raise (e)
self._write_inputs_s3(input_df)

logger.info("fetching outputs from athena")
try:
output_df = self._read_predictions_from_s3()
except Exception as e:
print(f"error {e}")
raise (e)
output_df = self._read_predictions_from_s3()

return output_df
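
After the simplification, fetch delegates straight to the S3 helpers and lets exceptions propagate instead of printing and re-raising. A minimal usage sketch of the fetcher's public surface (the IDs and the input path are placeholders, and DataLakeConfig is assumed to read its fields from the environment via pydantic-settings):

```python
from config.app import DataLakeConfig
from precalculator.fetcher import PredictionFetcher

config = DataLakeConfig()  # assumes athena_database etc. are supplied via environment variables
fetcher = PredictionFetcher(config, user_id="user-1", request_id="req-1", model_id="eos-model", dev=True)

status = fetcher.check_availability()
predictions = fetcher.fetch("s3://bucket/model-id/request-id.csv")  # placeholder input location
```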

46 changes: 42 additions & 4 deletions precalculator/models.py
@@ -1,15 +1,14 @@
from typing import Any

import pandas as pd
from pydantic import BaseModel, Field


class Prediction(BaseModel):
"""Dataclass to represent a single prediction"""

model_id: str
input_key: str
smiles: str
output: list[Any]
output: dict
model_id: str


class Metadata(BaseModel):
@@ -22,3 +21,42 @@ class Metadata(BaseModel):
pipeline_latest_start_time: int = Field(default=0)
pipeline_latest_duration: int = Field(default=0)
pipeline_meta_s3_uri: str = Field(default="")


class SchemaValidationError(Exception):
def __init__(self, errors: list[str]):
self.errors = errors
error_message = "\n".join(errors)
super().__init__(f"Schema validation failed with the following errors:\n{error_message}\n")


def validate_dataframe_schema(df: pd.DataFrame, model: BaseModel) -> None:
errors = []
schema = model.model_fields

for field_name, field in schema.items():
if field_name not in df.columns:
errors.append(f"Missing column: {field_name}")
else:
pandas_dtype = df[field_name].dtype
pydantic_type = field.annotation
if not _check_type_compatibility(pandas_dtype, pydantic_type):
errors.append(f"Column {field_name} has type {pandas_dtype}, expected {pydantic_type}")

for column in df.columns:
if column not in schema:
errors.append(f"Unexpected column: {column}")

if errors:
raise SchemaValidationError(errors)


def _check_type_compatibility(pandas_dtype, pydantic_type) -> bool: # noqa: ANN001
type_map = {
"object": [str, dict, list],
"int64": [int],
"float64": [float],
"bool": [bool],
"datetime64": [pd.Timestamp],
}
return pydantic_type in type_map.get(str(pandas_dtype)) # type: ignore
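
validate_dataframe_schema compares a DataFrame's columns and dtypes against a Pydantic model's declared fields and raises SchemaValidationError listing every mismatch. A small illustrative check against the new Prediction shape (the values are made up):

```python
import pandas as pd

from precalculator.models import Prediction, SchemaValidationError, validate_dataframe_schema

df = pd.DataFrame(
    {
        "input_key": ["KEY-0001"],
        "smiles": ["CCO"],
        "output": [{"score": 0.12}],
        "model_id": ["eos-model"],  # placeholder model ID
    }
)

try:
    validate_dataframe_schema(df, Prediction)  # all four columns present with object dtype -> passes
except SchemaValidationError as exc:
    print(exc)
```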
