Skip to content

Commit

Permalink
adding run.sh and main.py (#14)
Browse files Browse the repository at this point in the history
* adding

* changed service file

* Delete model/checkpoints/README.md

* Update main.py

* Update pack.py

* Create README.md

* Update pack.py

* Update run.sh

* Delete model/framework/input.csv

* Delete model/framework/output.csv

---------

Co-authored-by: Miquel Duran-Frigola <miquelduranfrigola@gmail.com>
Co-authored-by: Dhanshree Arora <DhanshreeA@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 3, 2024
1 parent eff0980 commit 49f0a80
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 16 deletions.
1 change: 1 addition & 0 deletions model/checkpoints/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Model Pretrained Parameters
39 changes: 39 additions & 0 deletions model/framework/code/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# imports
import os
import csv
import sys
from rdkit import Chem
from rdkit.Chem.Descriptors import MolWt

# parse arguments
input_file = sys.argv[1]
output_file = sys.argv[2]

# current file directory
root = os.path.dirname(os.path.abspath(__file__))

# my model
def molweight(smiles_list):
return [MolWt(Chem.MolFromSmiles(smi)) for smi in smiles_list]


# read SMILES from .csv file, assuming one column with header
with open(input_file, "r") as f:
reader = csv.reader(f)
next(reader) # skip header
smiles_list = [r[0] for r in reader]

# run model
outputs = molweight(smiles_list)

#check input and output have the same length
input_len = len(smiles_list)
output_len = len(outputs)
assert input_len == output_len

# write output in a .csv file
with open(output_file, "w") as f:
writer = csv.writer(f)
writer.writerow(["value"]) # header
for o in outputs:
writer.writerow([o])
1 change: 1 addition & 0 deletions model/framework/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python $1/code/main.py $2 $3
10 changes: 9 additions & 1 deletion pack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import os
from src.service import load_model
from src.service import Service
from src.service import CHECKPOINTS_BASEDIR, FRAMEWORK_BASEDIR

root = os.path.dirname(os.path.realpath(__file__))
mdl = load_model(
os.path.join(root, "model", FRAMEWORK_BASEDIR),
os.path.join(root, "model", CHECKPOINTS_BASEDIR),
)

service = Service()
service.pack("model", None)
service.pack("model", mdl)
service.save()
143 changes: 128 additions & 15 deletions src/service.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,140 @@
import random
import json
import collections
from rdkit import Chem
from rdkit.Chem import Descriptors
from typing import List

from bentoml import BentoService, api, artifacts
from bentoml.adapters import JsonInput
from bentoml.service.artifacts.common import JSONArtifact
from bentoml.types import JsonSerializable
from bentoml.service import BentoServiceArtifact

import pickle
import os
import shutil
import collections
import tempfile
import subprocess
import csv

CHECKPOINTS_BASEDIR = "checkpoints"
FRAMEWORK_BASEDIR = "framework"


def load_model(framework_dir, checkpoints_dir):
mdl = Model()
mdl.load(framework_dir, checkpoints_dir)
return mdl


def Float(x):
try:
return float(x)
except:
return None


class Model(object):
def __init__(self):
self.DATA_FILE = "_data.csv"
self.OUTPUT_FILE = "_output.csv"
self.RUN_FILE = "_run.sh"
self.LOG_FILE = "run.log"

def load(self, framework_dir, checkpoints_dir):
self.framework_dir = framework_dir
self.checkpoints_dir = checkpoints_dir

def set_checkpoints_dir(self, dest):
self.checkpoints_dir = os.path.abspath(dest)

def set_framework_dir(self, dest):
self.framework_dir = os.path.abspath(dest)

def run(self, input_list):
tmp_folder = tempfile.mkdtemp(prefix="eos-")
data_file = os.path.join(tmp_folder, self.DATA_FILE)
output_file = os.path.join(tmp_folder, self.OUTPUT_FILE)
log_file = os.path.join(tmp_folder, self.LOG_FILE)
with open(data_file, "w") as f:
f.write("input" + os.linesep)
for inp in input_list:
f.write(inp + os.linesep)
run_file = os.path.join(tmp_folder, self.RUN_FILE)
with open(run_file, "w") as f:
lines = [
"bash {0}/run.sh {0} {1} {2}".format(
self.framework_dir, data_file, output_file
)
]
f.write(os.linesep.join(lines))
cmd = "bash {0}".format(run_file)
with open(log_file, "w") as fp:
subprocess.Popen(
cmd, stdout=fp, stderr=fp, shell=True, env=os.environ
).wait()
with open(output_file, "r") as f:
reader = csv.reader(f)
h = next(reader)
R = []
for r in reader:
R += [
{"outcome": [Float(x) for x in r]}
] # <-- EDIT: Modify according to type of output (Float, String...)
meta = {"outcome": h}
result = {"result": R, "meta": meta}
shutil.rmtree(tmp_folder)
return result


class Artifact(BentoServiceArtifact):
def __init__(self, name):
super(Artifact, self).__init__(name)
self._model = None
self._extension = ".pkl"

def _copy_checkpoints(self, base_path):
src_folder = self._model.checkpoints_dir
dst_folder = os.path.join(base_path, "checkpoints")
if os.path.exists(dst_folder):
os.rmdir(dst_folder)
shutil.copytree(src_folder, dst_folder)

def _copy_framework(self, base_path):
src_folder = self._model.framework_dir
dst_folder = os.path.join(base_path, "framework")
if os.path.exists(dst_folder):
os.rmdir(dst_folder)
shutil.copytree(src_folder, dst_folder)

def _model_file_path(self, base_path):
return os.path.join(base_path, self.name + self._extension)

def pack(self, model):
self._model = model
return self

def load(self, path):
model_file_path = self._model_file_path(path)
model = pickle.load(open(model_file_path, "rb"))
model.set_checkpoints_dir(
os.path.join(os.path.dirname(model_file_path), "checkpoints")
)
model.set_framework_dir(
os.path.join(os.path.dirname(model_file_path), "framework")
)
return self.pack(model)

def get(self):
return self._model

def save(self, dst):
self._copy_checkpoints(dst)
self._copy_framework(dst)
pickle.dump(self._model, open(self._model_file_path(dst), "wb"))


@artifacts([JSONArtifact("model")])
@artifacts([Artifact("model")])
class Service(BentoService):
@api(input=JsonInput(), batch=True)
def run(self, input: List[JsonSerializable]):
"""
Calculate molecular weight
"""
input = input[0]
output = []
for inp in input:
mol = Chem.MolFromSmiles(inp["input"])
mw = Descriptors.MolWt(mol)
output += [{"mw": mw}]
input_list = [inp["input"] for inp in input]
output = self.artifacts.model.run(input_list)
return [output]

0 comments on commit 49f0a80

Please sign in to comment.