Skip to content

Commit 35531e8

Browse files
authored
feat: add OpenAI assistant api integration (#63)
* feat: add openai assistant api integration * add EventHandler * change api * fix hitl
1 parent 859e969 commit 35531e8

File tree

8 files changed

+178
-7
lines changed

8 files changed

+178
-7
lines changed

examples/github_notifier/main.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from openai import OpenAI
22
from npiai.core import Agent
3-
from npiai.core.hitl import ConsoleHITLHandler
3+
from npiai.tools.hitl.console import ConsoleHandler
44
from npiai.app.github import GitHub
55
from npiai.app.google import Gmail
66

@@ -11,7 +11,7 @@ def main():
1111
prompt='You are a Github Notifier that informs users when a new issue or pull request is open.',
1212
description='Github Notifier that informs users when a new issue or pull request is open',
1313
llm=OpenAI(),
14-
hitl_handler=ConsoleHITLHandler()
14+
hitl_handler=ConsoleHandler()
1515
)
1616

1717
agent.use(GitHub(), Gmail())

examples/openai/assistant.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
2+
from openai import OpenAI
3+
from npiai.app import Gmail, GitHub
4+
from npiai.tools.hitl.console import ConsoleHandler
5+
from npiai.core.toolset import ToolSet
6+
from npiai.integration.oai import EventHandler
7+
8+
9+
if __name__ == "__main__":
10+
client = OpenAI(api_key="xxxxx")
11+
ts = ToolSet(
12+
llm=client,
13+
hitl_handler=ConsoleHandler(),
14+
tools=[
15+
GitHub(access_token="xxxxx"),
16+
],
17+
)
18+
19+
assistant = client.beta.assistants.create(
20+
name="GitHub Issue Assistant",
21+
instructions="You are an Assistant can maintain issue comment for repo npi-ai/npi.",
22+
model="gpt-4o",
23+
tools=ts.openai(),
24+
)
25+
26+
thread = client.beta.threads.create()
27+
message = client.beta.threads.messages.create(
28+
thread_id=thread.id,
29+
role="user",
30+
content="what's title of issue #27 of repo npi-ai/npi?",
31+
)
32+
33+
def stream_handler(run_id: str, stream):
34+
for text in stream.text_deltas:
35+
print(text, end="", flush=True)
36+
37+
eh = EventHandler(toolset=ts, llm=client, thread_id=thread.id, stream_handler=stream_handler)
38+
with client.beta.threads.runs.stream(
39+
thread_id=thread.id,
40+
assistant_id=assistant.id,
41+
event_handler=eh,
42+
) as _stream:
43+
_stream.until_done()

sdk/python/npiai/app/github.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
from npiai.core.base import App
22
from npiai_proto import api_pb2
3+
from typing_extensions import override
34

45

56
class GitHub(App):
67

7-
def __init__(self, npi_endpoint: str = None):
8+
def __init__(self, access_token: str, npi_endpoint: str = None):
89
super().__init__(
910
app_name="github",
1011
app_type=api_pb2.GITHUB,
1112
endpoint=npi_endpoint,
1213
)
14+
self.access_token = access_token
15+
16+
@override
17+
def authorize(self):
18+
super()._authorize(credentials={"access_token": self.access_token})

sdk/python/npiai/core/base.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,17 @@ def __init__(
3333
endpoint: str = "localhost:9140",
3434
hitl_handler: hitl.HITLHandler = None,
3535
npi_token: str = None,
36+
insecure: bool = True,
3637
):
3738
self.__app_name = app_name
3839
if endpoint is None:
3940
endpoint = "localhost:9140"
4041
self.__npi_endpoint = endpoint
4142
self.__app_type = app_type
42-
43-
channel = grpc.secure_channel(target=self.__npi_endpoint, credentials=grpc.ssl_channel_credentials())
43+
if insecure:
44+
channel = grpc.insecure_channel(self.__npi_endpoint)
45+
else:
46+
channel = grpc.secure_channel(target=self.__npi_endpoint, credentials=grpc.ssl_channel_credentials())
4447
self.stub = api_pb2_grpc.AppServerStub(channel)
4548
self.hitl_handler = hitl_handler
4649
self.__npi_token = npi_token
@@ -104,15 +107,24 @@ def chat(self, msg: str) -> str:
104107
case api_pb2.ResponseCode.ACTION_REQUIRED:
105108
resp = self.stub.Chat(request=self.__call_human(resp), metadata=self.__get_metadata())
106109
case _:
110+
logger.error(f'[{self.__app_name}]: Error: failed to call function, unknown response code {resp.code}')
107111
raise Exception("Error: failed to call function")
108112

109113
def hitl(self, handler: hitl.HITLHandler):
110114
self.hitl_handler = handler
111115

112-
# @abstractmethod
113-
def authorize(self, **kwargs):
116+
def authorize(self):
114117
pass
115118

119+
def _authorize(self, credentials: dict[str, str]):
120+
self.stub.Authorize(
121+
request=api_pb2.AuthorizeRequest(
122+
type=self.__app_type,
123+
credentials=credentials,
124+
),
125+
)
126+
logger.info(f'[{self.__app_name}]: Authorized')
127+
116128
def __call_human(self, resp: api_pb2.Response) -> api_pb2.Request:
117129
human_resp = self.hitl_handler.handle(
118130
hitl.convert_to_hitl_request(

sdk/python/npiai/core/toolset.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from npiai.core.base import App
2+
from npiai.core.hitl import HITLHandler
3+
from openai import Client
4+
5+
from npiai.tools.hitl import EmptyHandler
6+
import json
7+
from typing import List
8+
9+
10+
class ToolSet:
11+
12+
def __init__(self, tools: List[App], llm: Client, hitl_handler: HITLHandler = None, ):
13+
self._llm = llm
14+
self._hitl_handler = hitl_handler
15+
if self._hitl_handler is None:
16+
self._hitl_handler = EmptyHandler()
17+
18+
if len(tools) == 0:
19+
raise Exception("At least one tool is required.")
20+
21+
self._tools = {}
22+
for tool in tools:
23+
tool.authorize()
24+
tool.hitl(self._hitl_handler)
25+
self._tools[tool.tool_name()] = tool
26+
27+
def openai(self):
28+
tools = []
29+
for tool in self._tools.values():
30+
tools.append(tool.schema())
31+
return tools
32+
33+
def call(self, body) -> str:
34+
tool = self._tools.get(body.function.name)
35+
if tool is None:
36+
return "Tool not found"
37+
params = json.loads(body.function.arguments)
38+
39+
return tool.chat(params['message'])

sdk/python/npiai/integration/oai.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing_extensions import override
2+
3+
from openai import AssistantEventHandler, Client
4+
from npiai.core.toolset import ToolSet
5+
6+
7+
class EventHandler(AssistantEventHandler):
8+
def __init__(self, toolset: ToolSet, llm: Client, thread_id: str, stream_handler=None):
9+
super().__init__()
10+
self.toolset = toolset
11+
self.llm = llm
12+
self.thread_id = thread_id
13+
self.stream_handler = stream_handler
14+
15+
@override
16+
def on_event(self, event):
17+
# Retrieve events that are denoted with 'requires_action'
18+
# since these will have our tool_calls
19+
if event.event == 'thread.run.requires_action':
20+
run_id = event.data.id # Retrieve the run ID from the event data
21+
self.handle_requires_action(event.data, run_id)
22+
23+
def handle_requires_action(self, data, run_id):
24+
tool_outputs = []
25+
26+
for tool in data.required_action.submit_tool_outputs.tool_calls:
27+
tool_outputs.append({
28+
"tool_call_id": tool.id,
29+
"output": self.toolset.call(tool),
30+
})
31+
32+
# Submit all tool_outputs at the same time
33+
self.submit_tool_outputs(tool_outputs, run_id)
34+
35+
def submit_tool_outputs(self, tool_outputs, run_id):
36+
# Use the submit_tool_outputs_stream helper
37+
with self.llm.beta.threads.runs.submit_tool_outputs_stream(
38+
thread_id=self.thread_id,
39+
run_id=run_id,
40+
tool_outputs=tool_outputs,
41+
event_handler=EventHandler(self.toolset, self.llm, self.thread_id),
42+
) as stream:
43+
self.stream_handler(run_id, stream)
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .console import ConsoleHandler
2+
from .empty import EmptyHandler
3+
from .twilio import TwilioHandler
4+
5+
__all__ = ['ConsoleHandler', 'EmptyHandler', 'TwilioHandler']

sdk/python/npiai/tools/hitl/empty.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from npiai.core.hitl import (
2+
HITLHandler, HITLRequest, HITLResponse, ActionRequestCode, ActionResultCode, ACTION_APPROVED, ACTION_DENIED)
3+
4+
from npiai_proto import api_pb2
5+
6+
7+
class EmptyHandler(HITLHandler):
8+
def handle(self, req: HITLRequest) -> HITLResponse:
9+
match req.code:
10+
case ActionRequestCode.INFORMATION:
11+
return HITLResponse(
12+
code=ActionResultCode.SUCCESS,
13+
msg="No human can response your help",
14+
)
15+
case ActionRequestCode.CONFIRMATION:
16+
return HITLResponse(
17+
code=ActionResultCode.APPROVED,
18+
msg="automatically approved",
19+
)
20+
return ACTION_DENIED
21+
22+
def type(self) -> api_pb2.ActionType:
23+
return api_pb2.ActionType.CONSOLE

0 commit comments

Comments
 (0)