-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adding * changed service file * Delete model/checkpoints/README.md * Update main.py * Update pack.py * Create README.md * Update pack.py * Update run.sh * Delete model/framework/input.csv * Delete model/framework/output.csv --------- Co-authored-by: Miquel Duran-Frigola <miquelduranfrigola@gmail.com> Co-authored-by: Dhanshree Arora <DhanshreeA@users.noreply.github.com>
- Loading branch information
1 parent
eff0980
commit 49f0a80
Showing
5 changed files
with
178 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Model Pretrained Parameters |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# imports | ||
import os | ||
import csv | ||
import sys | ||
from rdkit import Chem | ||
from rdkit.Chem.Descriptors import MolWt | ||
|
||
# parse arguments: argv[1] is the input CSV path, argv[2] the output CSV path
input_file, output_file = sys.argv[1], sys.argv[2]

# directory containing this script
root = os.path.dirname(os.path.abspath(__file__))
|
||
# my model | ||
def molweight(smiles_list):
    """Return the molecular weight of every SMILES string in *smiles_list*.

    The result has exactly one entry per input, in the same order. An entry
    is ``None`` when the corresponding SMILES string cannot be parsed.
    """
    weights = []
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        # MolFromSmiles signals a parse failure by returning None; passing
        # that on to MolWt would abort the whole batch, so emit None for the
        # offending entry and keep the output aligned with the input.
        weights.append(MolWt(mol) if mol is not None else None)
    return weights
|
||
|
||
# read SMILES from .csv file, assuming one column with header
# (newline="" lets the csv module handle line endings itself)
with open(input_file, "r", newline="") as f:
    reader = csv.reader(f)
    next(reader, None)  # skip header; tolerate a completely empty file
    smiles_list = [r[0] for r in reader]

# run model
outputs = molweight(smiles_list)

# check input and output have the same length
# (explicit raise instead of assert, which is stripped under `python -O`)
input_len = len(smiles_list)
output_len = len(outputs)
if input_len != output_len:
    raise RuntimeError(
        "Model returned {0} outputs for {1} inputs".format(output_len, input_len)
    )

# write output in a .csv file
# (newline="" prevents doubled carriage returns on platforms that translate
# line endings)
with open(output_file, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["value"])  # header
    writer.writerows([o] for o in outputs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Usage: run.sh <framework_dir> <input_csv> <output_csv>
# Arguments are quoted so paths containing spaces or glob characters survive.
python "$1/code/main.py" "$2" "$3"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,14 @@ | ||
import os | ||
from src.service import load_model | ||
from src.service import Service | ||
from src.service import CHECKPOINTS_BASEDIR, FRAMEWORK_BASEDIR | ||
|
||
# Locate the model folders relative to this script, wrap them in a Model and
# pack that Model into the service bundle.
root = os.path.dirname(os.path.realpath(__file__))
mdl = load_model(
    os.path.join(root, "model", FRAMEWORK_BASEDIR),
    os.path.join(root, "model", CHECKPOINTS_BASEDIR),
)

service = Service()
# Pack the loaded model only; the earlier placeholder pack of None was
# immediately superseded by this call and has been dropped.
service.pack("model", mdl)
service.save()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,140 @@ | ||
import random | ||
import json | ||
import collections | ||
from rdkit import Chem | ||
from rdkit.Chem import Descriptors | ||
from typing import List | ||
|
||
from bentoml import BentoService, api, artifacts | ||
from bentoml.adapters import JsonInput | ||
from bentoml.service.artifacts.common import JSONArtifact | ||
from bentoml.types import JsonSerializable | ||
from bentoml.service import BentoServiceArtifact | ||
|
||
import pickle | ||
import os | ||
import shutil | ||
import collections | ||
import tempfile | ||
import subprocess | ||
import csv | ||
|
||
CHECKPOINTS_BASEDIR = "checkpoints" | ||
FRAMEWORK_BASEDIR = "framework" | ||
|
||
|
||
def load_model(framework_dir, checkpoints_dir):
    """Build a ``Model`` pointing at the given framework and checkpoints folders."""
    model = Model()
    model.load(framework_dir, checkpoints_dir)
    return model
|
||
|
||
def Float(x):
    """Convert *x* to ``float``, returning ``None`` when it is not convertible.

    Only conversion failures are absorbed; unrelated exceptions such as
    ``KeyboardInterrupt`` propagate instead of being silently swallowed.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        return None
|
||
|
||
class Model(object):
    """Runs the model framework's ``run.sh`` on a list of inputs.

    Inputs are written to a temporary CSV file, ``run.sh`` in the framework
    folder is executed through ``bash``, and the CSV it produces is parsed
    back into a result dictionary.
    """

    def __init__(self):
        # File names used inside the per-run temporary folder.
        self.DATA_FILE = "_data.csv"
        self.OUTPUT_FILE = "_output.csv"
        self.RUN_FILE = "_run.sh"
        self.LOG_FILE = "run.log"

    def load(self, framework_dir, checkpoints_dir):
        """Remember the framework and checkpoints folders as given."""
        self.framework_dir = framework_dir
        self.checkpoints_dir = checkpoints_dir

    def set_checkpoints_dir(self, dest):
        """Point the model at *dest* (made absolute) for its checkpoints."""
        self.checkpoints_dir = os.path.abspath(dest)

    def set_framework_dir(self, dest):
        """Point the model at *dest* (made absolute) for its framework code."""
        self.framework_dir = os.path.abspath(dest)

    def run(self, input_list):
        """Run the framework on *input_list* and return the parsed output.

        Args:
            input_list: strings, written one per line under an ``input``
                header.

        Returns:
            ``{"result": [{"outcome": [...]}, ...], "meta": {"outcome": header}}``
            with one ``result`` entry per row of the output CSV and the CSV
            header row under ``meta``.

        Raises:
            FileNotFoundError: the run did not produce an output file; the
                message carries the exit code and the captured log.
        """
        # Local import: shlex is only needed here, and this keeps the fix
        # independent of the module-level import block.
        import shlex

        tmp_folder = tempfile.mkdtemp(prefix="eos-")
        try:
            data_file = os.path.join(tmp_folder, self.DATA_FILE)
            output_file = os.path.join(tmp_folder, self.OUTPUT_FILE)
            log_file = os.path.join(tmp_folder, self.LOG_FILE)
            run_file = os.path.join(tmp_folder, self.RUN_FILE)

            # Text-mode files translate "\n" themselves; writing os.linesep
            # would double the carriage return on platforms where it is "\r\n".
            with open(data_file, "w") as f:
                f.write("input\n")
                f.writelines(inp + "\n" for inp in input_list)

            # Quote every path so folders containing spaces or shell
            # metacharacters do not break (or alter) the command.
            with open(run_file, "w") as f:
                f.write(
                    "bash {0} {1} {2} {3}".format(
                        shlex.quote(os.path.join(self.framework_dir, "run.sh")),
                        shlex.quote(self.framework_dir),
                        shlex.quote(data_file),
                        shlex.quote(output_file),
                    )
                )

            # An argument list needs no shell=True and no string quoting.
            with open(log_file, "w") as fp:
                returncode = subprocess.Popen(
                    ["bash", run_file], stdout=fp, stderr=fp, env=os.environ
                ).wait()

            if not os.path.exists(output_file):
                # The log is deleted together with the temp folder, so put
                # its content into the error for debugging.
                with open(log_file, "r", errors="replace") as fp:
                    log = fp.read()
                raise FileNotFoundError(
                    "Model run produced no output file (exit code {0}). "
                    "Log:\n{1}".format(returncode, log)
                )

            with open(output_file, "r", newline="") as f:
                reader = csv.reader(f)
                h = next(reader)
                R = [
                    {"outcome": [Float(x) for x in r]} for r in reader
                ]  # <-- EDIT: Modify according to type of output (Float, String...)
        finally:
            # Always remove the temp folder, including on failure; cleanup is
            # best-effort so it never masks the original error.
            shutil.rmtree(tmp_folder, ignore_errors=True)

        meta = {"outcome": h}
        return {"result": R, "meta": meta}
|
||
|
||
class Artifact(BentoServiceArtifact):
    """BentoML artifact that bundles a model object with its folders.

    Saving copies the model's checkpoints and framework folders next to a
    pickle of the model object; loading unpickles the model and re-points it
    at the copied folders.
    """

    def __init__(self, name):
        super(Artifact, self).__init__(name)
        self._model = None
        self._extension = ".pkl"

    @staticmethod
    def _replace_tree(src_folder, dst_folder):
        """Copy *src_folder* to *dst_folder*, replacing any existing copy."""
        if os.path.exists(dst_folder):
            # os.rmdir only removes empty directories; a previous save leaves
            # a populated folder, so remove the whole tree.
            shutil.rmtree(dst_folder)
        shutil.copytree(src_folder, dst_folder)

    def _copy_checkpoints(self, base_path):
        """Copy the model's checkpoints folder to ``<base_path>/checkpoints``."""
        self._replace_tree(
            self._model.checkpoints_dir, os.path.join(base_path, "checkpoints")
        )

    def _copy_framework(self, base_path):
        """Copy the model's framework folder to ``<base_path>/framework``."""
        self._replace_tree(
            self._model.framework_dir, os.path.join(base_path, "framework")
        )

    def _model_file_path(self, base_path):
        """Return the path of the pickled model inside *base_path*."""
        return os.path.join(base_path, self.name + self._extension)

    def pack(self, model):
        """Attach *model* to this artifact and return the artifact."""
        self._model = model
        return self

    def load(self, path):
        """Load the pickled model from *path* and re-point it at its folders."""
        model_file_path = self._model_file_path(path)
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only load bundles from a trusted source.
        with open(model_file_path, "rb") as f:
            model = pickle.load(f)
        bundle_dir = os.path.dirname(model_file_path)
        model.set_checkpoints_dir(os.path.join(bundle_dir, "checkpoints"))
        model.set_framework_dir(os.path.join(bundle_dir, "framework"))
        return self.pack(model)

    def get(self):
        """Return the packed model (``None`` until packed or loaded)."""
        return self._model

    def save(self, dst):
        """Write the checkpoints, framework and pickled model into *dst*."""
        self._copy_checkpoints(dst)
        self._copy_framework(dst)
        with open(self._model_file_path(dst), "wb") as f:
            pickle.dump(self._model, f)
|
||
|
||
@artifacts([Artifact("model")])
class Service(BentoService):
    """BentoML service that delegates prediction to the packed model artifact."""

    @api(input=JsonInput(), batch=True)
    def run(self, input: List[JsonSerializable]):
        """
        Calculate molecular weight

        Reads the first element of the batch as a list of ``{"input": ...}``
        records, passes the extracted inputs to the packed model's ``run``
        method, and returns its result wrapped in a one-element list.
        """
        # Only the packed model is used. The earlier in-process RDKit loop
        # computed a value that was immediately overwritten (and failed on
        # unparsable SMILES), so it has been removed.
        records = input[0]
        input_list = [inp["input"] for inp in records]
        output = self.artifacts.model.run(input_list)
        return [output]