Skip to content

Commit

Permalink
Add embedding_model as input param to Agent and ConversationalAgent
Browse files Browse the repository at this point in the history
  • Loading branch information
krohling committed Dec 19, 2023
1 parent 21354b9 commit f450e5f
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 11 deletions.
19 changes: 16 additions & 3 deletions bondai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
from typing import Dict, List, Tuple, Callable
from bondai.util import EventMixin, Runnable, load_local_resource
from bondai.tools import Tool, ResponseQueryTool
from bondai.models import LLM
from bondai.models import LLM, EmbeddingModel
from bondai.memory import MemoryManager
from bondai.prompt import JinjaPromptBuilder
from bondai.models.openai import OpenAILLM, OpenAIModelNames, get_total_cost
from bondai.models.openai import (
OpenAILLM,
OpenAIEmbeddingModel,
OpenAIModelNames,
get_total_cost,
)
from .conversation_member import ConversationMember
from .messages import AgentMessage, AgentMessageList, SystemMessage, ToolUsageMessage
from .compression import summarize_conversation, summarize_messages
Expand Down Expand Up @@ -56,6 +61,7 @@ class Agent(EventMixin, Runnable):
def __init__(
self,
llm: LLM | None = None,
embedding_model: EmbeddingModel | None = None,
tools: List[Tool] | None = None,
quiet: bool = True,
allowed_events: List[str] | None = None,
Expand Down Expand Up @@ -85,6 +91,10 @@ def __init__(

if llm is None:
llm = OpenAILLM(OpenAIModelNames.GPT4_0613)
if embedding_model is None:
embedding_model = OpenAIEmbeddingModel(
OpenAIModelNames.TEXT_EMBEDDING_ADA_002
)
if tools is None:
tools = []
if system_prompt_sections is None:
Expand All @@ -96,6 +106,7 @@ def __init__(
self._status: AgentStatus = AgentStatus.IDLE
self._messages = AgentMessageList(messages=messages)
self._llm: LLM = llm
self._embedding_model: EmbeddingModel = embedding_model
self._tools: List[Tool] = tools
self._quiet: bool = quiet
self._system_prompt_sections: List[Callable[[], str]] = system_prompt_sections
Expand Down Expand Up @@ -317,7 +328,9 @@ def _run_tool_loop(
last_error_message = None
local_messages = []
self._force_stop = False
response_query_tool = ResponseQueryTool()
response_query_tool = ResponseQueryTool(
llm=self._llm, embedding_model=self._embedding_model
)

def append_message(message):
if isinstance(message, SystemMessage):
Expand Down
10 changes: 6 additions & 4 deletions bondai/agents/conversational_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ExitConversationTool,
)
from bondai.prompt import JinjaPromptBuilder
from bondai.models.llm import LLM
from bondai.models import LLM, EmbeddingModel
from bondai.models.openai import OpenAILLM, OpenAIModelNames, get_total_cost
from .agent import (
Agent,
Expand Down Expand Up @@ -47,6 +47,7 @@ class ConversationalAgent(Agent, ConversationMember):
def __init__(
self,
llm: LLM | None = None,
embedding_model: EmbeddingModel | None = None,
tools: List[Tool] | None = None,
messages: List[AgentMessage] | None = None,
name: str = DEFAULT_AGENT_NAME,
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(
Agent.__init__(
self,
llm=llm,
embedding_model=embedding_model,
quiet=quiet,
tools=tools,
messages=messages,
Expand Down Expand Up @@ -251,7 +253,7 @@ def validate_recipient(recipient_name: str):
"instructions": self.instructions,
"conversation_enabled": self._enable_conversation_tools
or self._enable_conversational_content_responses,
"allow_exit": self._enable_exit_conversation,
"enable_exit_conversation": self._enable_exit_conversation,
}

tool_result = self._run_tool_loop(
Expand Down Expand Up @@ -320,9 +322,9 @@ def to_dict(self) -> Dict:
data["persona"] = self._persona
data["persona_summary"] = self._persona_summary
data["instructions"] = self.instructions
data["allow_exit"] = self._enable_exit_conversation
data["quiet"] = self._quiet
data["enable_conversation_tools"] = self._enable_conversation_tools
data["enable_exit_conversation"] = self._enable_exit_conversation
data[
"enable_conversational_content_responses"
] = self._enable_conversational_content_responses
Expand Down Expand Up @@ -355,7 +357,7 @@ def from_dict(
persona=data["persona"],
persona_summary=data["persona_summary"],
instructions=data["instructions"],
allow_exit=data["allow_exit"],
enable_exit_conversation=data["enable_exit_conversation"],
quiet=data["quiet"],
enable_conversation_tools=data["enable_conversation_tools"],
enable_conversational_content_responses=data[
Expand Down
2 changes: 1 addition & 1 deletion bondai/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def build_agents(llm: LLM) -> GroupConversation:
tools=[AgentTool(task_execution_agent)],
enable_conversation_tools=False,
enable_conversational_content_responses=True,
allow_exit=False,
enable_exit_conversation=False,
memory_manager=MemoryManager(
core_memory_datasource=PersistentCoreMemoryDataSource(
file_path="./.memory/user_liason_core_memory.json",
Expand Down
2 changes: 1 addition & 1 deletion website/docs/api-spec/create-agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ This API will create a Conversational Agent.

```json
{
"allow_exit":true,
"enable_exit_conversation":true,
"enable_conversation_tools":false,
"enable_conversational_content_responses":true,
"id":"34c2262b-1a9b-4ace-9b74-54e892ea59a2",
Expand Down
2 changes: 1 addition & 1 deletion website/docs/api-spec/get-agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ This API returns the current state of an Agent.

```json
{
"allow_exit":true,
"enable_exit_conversation":true,
"enable_conversation_tools":false,
"enable_conversational_content_responses":true,
"id":"34c2262b-1a9b-4ace-9b74-54e892ea59a2",
Expand Down
2 changes: 1 addition & 1 deletion website/docs/api-spec/list-agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ This API returns a list of all active Agents.
```json
[
{
"allow_exit":true,
"enable_exit_conversation":true,
"enable_conversation_tools":false,
"enable_conversational_content_responses":true,
"id":"34c2262b-1a9b-4ace-9b74-54e892ea59a2",
Expand Down

0 comments on commit f450e5f

Please sign in to comment.