Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
k82cn committed Feb 16, 2025
1 parent 54e030f commit 84bc7e4
Show file tree
Hide file tree
Showing 38 changed files with 3,395 additions and 916 deletions.
45 changes: 45 additions & 0 deletions examples/pypi/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env python3

# Copyright 2025 The Flame Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flame
import argparse

parser = argparse.ArgumentParser(description='Flame Pi Python Example.')
parser.add_argument('-n', '--task_num', type=int, help="The total number of tasks in the session.")
parser.add_argument('-i', '--task_input', type=int, help="The input of each task to calculate Pi.")
args = parser.parse_args()

area = 0.0


def get_circle_area(task):
global area
area += float(task.output)


conn = flame.connect("127.0.0.1:8080")
ssn = conn.create_session(application="pi", slots=1)

# Convert args.task_input into bytes type.
task_input = str(args.task_input).encode()
task_inputs = [task_input] * args.task_num

# Submit all task inputs to Flame, and wait for the result.
ssn.run_all_tasks(task_inputs=task_inputs, on_completed=get_circle_area)

# Calculate the Pi.
pi = 4 * area / (args.task_input * args.task_num)

print("pi = 4*({}/{}) = {}".format(area, args.task_input * args.task_num, pi))

ssn.close()
17 changes: 10 additions & 7 deletions sdk/python/rpc/__init__.py → examples/pypi/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2023 The Flame Authors.

# Copyright 2025 The Flame Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -9,12 +10,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["frontend_pb2", "frontend_pb2_grpc", "types_pb2", "types_pb2_grpc"]
import flame

def PiService(flame.FlameService):
def __init__(self):
self.area = 0.0


__path__ = ["."]

import frontend_pb2
import frontend_pb2_grpc
import types_pb2
import types_pb2_grpc

# Start Flame Pi Service
flame.start_service(PiService)
1 change: 0 additions & 1 deletion sdk/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@
# limitations under the License.

__all__ = ["flame"]
__path__ = ["rpc"]
Binary file modified sdk/python/__pycache__/flame.cpython-312.pyc
Binary file not shown.
65 changes: 64 additions & 1 deletion sdk/python/flame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from enum import Enum
from concurrent import futures
from urllib.parse import urlparse
import grpc
import logging

from rpc import *
import frontend_pb2_grpc
import frontend_pb2
import types_pb2
import shim_pb2_grpc
import shim_pb2

def connect(addr):
channel = grpc.insecure_channel(addr)
Expand Down Expand Up @@ -83,3 +91,58 @@ def __init__(self, task):
self.input = task.spec.input
self.output = task.spec.output
self.state = TaskState(task.status.state)

class TaskInput:
pass

class TaskOutput:
pass

class CommonData:
pass

class ApplicationContext:
def __init__(self, app_context):
self.name = app_context.name

class SessonContext:
def __init__(self, ssn_context):
self.session_id = ssn_context.id,
self.application = ApplicationContext(ssn_context.application),
self.common_data = ssn_context.common_data,

class TaskContext:
pass


class FlameService:
def on_session_enter(self, ssn_context):
pass

def on_session_enter(self):
pass

def on_task_invoke(self, task_context) -> TaskOutput:
pass

class GrpcShimService(shim_pb2_grpc.GrpcShimServicer):
def __init__(self, service):
self.service = service

def OnSessionEnter(self, ctx):
ssn_ctx = SessonContext(ctx)
self.service.on_session_enter(ssn_ctx)



def start_service(service):
log = logging.getLogger(__name__)
url = os.environ['FLAME_SERVICE_MANAGER']
o = urlparse(url)
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
shim_pb2_grpc.add_GrpcShimServicer_to_server(GrpcShimService(service), server)
server.add_insecure_port("[::]:" + o.port)
log.info("The Flame service was started at " + url)

server.start()
server.wait_for_termination()
10 changes: 10 additions & 0 deletions sdk/python/flame/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
__all__ = ["client", "service"]


import os
import sys

script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "."))
sys.path.append(os.path.join(script_dir, "rpc"))

85 changes: 85 additions & 0 deletions sdk/python/flame/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2023 The Flame Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
import grpc

from rpc import *

def connect(addr):
channel = grpc.insecure_channel(addr)
return Connection(channel)

class Connection:
def __init__(self, channel):
self.channel = channel

def create_session(self, *, application, slots):
stub = frontend_pb2_grpc.FrontendStub(self.channel)
spec = types_pb2.SessionSpec(application=application, slots=slots)
req = frontend_pb2.CreateSessionRequest(session=spec)
ssn = stub.CreateSession(req)
return Session(stub, ssn)


class SessionState(Enum):
Open = 0
Closed = 1


class Session:
def __init__(self, stub, ssn):
self.stub = stub
self.id = ssn.metadata.id

def create_task(self, task_input):
spec = types_pb2.TaskSpec(input=task_input, session_id=self.id)
task = self.stub.CreateTask(frontend_pb2.CreateTaskRequest(task=spec))
return Task(task)

def get_task(self, task_id):
req = frontend_pb2.GetTaskRequest(task_id=task_id, session_id=self.id)
task = self.stub.GetTask(req)
return Task(task)

def watch_task(self, *, task_id, on_completed=None, on_error=None):
req = frontend_pb2.WatchTaskRequest(task_id=task_id, session_id=self.id)
tasks = self.stub.WatchTask(req)
for task in tasks:
state = TaskState(task.status.state)
if state == TaskState.Succeed and on_completed != None:
on_completed(Task(task))

def run_all_tasks(self, *, task_inputs, on_completed=None, on_error=None):
tasks = []
for task_input in task_inputs:
tasks.append(self.create_task(task_input))
for task in tasks:
self.watch_task(task_id=task.id, on_completed=on_completed, on_error=on_error)

def close(self):
self.stub.CloseSession(frontend_pb2.CloseSessionRequest(session_id=self.id))


class TaskState(Enum):
Pending = 0
Running = 1
Succeed = 2
Failed = 3


class Task:
def __init__(self, task):
self.id = task.metadata.id
self.session_id = task.spec.session_id
self.input = task.spec.input
self.output = task.spec.output
self.state = TaskState(task.status.state)
41 changes: 41 additions & 0 deletions sdk/python/flame/flame_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python

# Copyright 2023 The Flame Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import client as flame
import unittest

class FlameTestCase(unittest.TestCase):
def test_session_creation(self):
conn = flame.connect("127.0.0.1:8080")
ssn = conn.create_session(application="flmexec", slots=1)
self.assertIsNotNone(ssn)
ssn.close()

def test_multi_session_creation(self):
conn = flame.connect("127.0.0.1:8080")
for i in range(5):
ssn = conn.create_session(application="flmexec", slots=1)
self.assertIsNotNone(ssn)
ssn.close()

def test_session_creation_with_tasks(self):
conn = flame.connect("127.0.0.1:8080")
ssn = conn.create_session(application="flmexec", slots=1)
self.assertIsNotNone(ssn)
task = ssn.create_task(None)
ssn.watch_task(task_id = task.id)
ssn.close()


if __name__ == '__main__':
unittest.main()
22 changes: 22 additions & 0 deletions sdk/python/flame/rpc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# import os
# import sys
# script_dir = os.path.dirname(os.path.realpath(__file__))

# sys.path.append(os.path.join(script_dir, "."))


__all__ = ["types_pb2", "types_pb2_grpc", "frontend_pb2", "frontend_pb2_grpc", "shim_pb2", "shim_pb2_grpc"]
# __path__ = [".", "rpc"]

import os
import sys

script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "."))

# from .types_pb2 import types_pb2
# from .types_pb2_grpc import types_pb2_grpc
# from .frontend_pb2 import frontend_pb2
# from .frontend_pb2_grpc import frontend_pb2_grpc
# from .shim_pb2 import shim_pb2
# from .shim_pb2_grpc import shim_pb2_grpc
67 changes: 67 additions & 0 deletions sdk/python/flame/rpc/frontend_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 84bc7e4

Please sign in to comment.