diff --git a/bondai/agents/agent.py b/bondai/agents/agent.py
index 709fc55..72bf047 100644
--- a/bondai/agents/agent.py
+++ b/bondai/agents/agent.py
@@ -6,10 +6,15 @@
 from typing import Dict, List, Tuple, Callable
 from bondai.util import EventMixin, Runnable, load_local_resource
 from bondai.tools import Tool, ResponseQueryTool
-from bondai.models import LLM
+from bondai.models import LLM, EmbeddingModel
 from bondai.memory import MemoryManager
 from bondai.prompt import JinjaPromptBuilder
-from bondai.models.openai import OpenAILLM, OpenAIModelNames, get_total_cost
+from bondai.models.openai import (
+    OpenAILLM,
+    OpenAIEmbeddingModel,
+    OpenAIModelNames,
+    get_total_cost,
+)
 from .conversation_member import ConversationMember
 from .messages import AgentMessage, AgentMessageList, SystemMessage, ToolUsageMessage
 from .compression import summarize_conversation, summarize_messages
@@ -56,6 +61,7 @@ class Agent(EventMixin, Runnable):
     def __init__(
         self,
         llm: LLM | None = None,
+        embedding_model: EmbeddingModel | None = None,
         tools: List[Tool] | None = None,
         quiet: bool = True,
         allowed_events: List[str] | None = None,
@@ -85,6 +91,10 @@ def __init__(
 
         if llm is None:
             llm = OpenAILLM(OpenAIModelNames.GPT4_0613)
+        if embedding_model is None:
+            embedding_model = OpenAIEmbeddingModel(
+                OpenAIModelNames.TEXT_EMBEDDING_ADA_002
+            )
         if tools is None:
             tools = []
         if system_prompt_sections is None:
@@ -96,6 +106,7 @@ def __init__(
         self._status: AgentStatus = AgentStatus.IDLE
         self._messages = AgentMessageList(messages=messages)
         self._llm: LLM = llm
+        self._embedding_model: EmbeddingModel = embedding_model
         self._tools: List[Tool] = tools
         self._quiet: bool = quiet
         self._system_prompt_sections: List[Callable[[], str]] = system_prompt_sections
@@ -317,7 +328,9 @@ def _run_tool_loop(
         last_error_message = None
         local_messages = []
         self._force_stop = False
-        response_query_tool = ResponseQueryTool()
+        response_query_tool = ResponseQueryTool(
+            llm=self._llm, embedding_model=self._embedding_model
+        )
 
         def append_message(message):
             if isinstance(message, SystemMessage):
diff --git a/bondai/agents/conversational_agent.py b/bondai/agents/conversational_agent.py
index 8744a11..a82da34 100644
--- a/bondai/agents/conversational_agent.py
+++ b/bondai/agents/conversational_agent.py
@@ -13,7 +13,7 @@
     ExitConversationTool,
 )
 from bondai.prompt import JinjaPromptBuilder
-from bondai.models.llm import LLM
+from bondai.models import LLM, EmbeddingModel
 from bondai.models.openai import OpenAILLM, OpenAIModelNames, get_total_cost
 from .agent import (
     Agent,
@@ -47,6 +47,7 @@ class ConversationalAgent(Agent, ConversationMember):
     def __init__(
         self,
         llm: LLM | None = None,
+        embedding_model: EmbeddingModel | None = None,
         tools: List[Tool] | None = None,
         messages: List[AgentMessage] | None = None,
         name: str = DEFAULT_AGENT_NAME,
@@ -82,6 +83,7 @@ def __init__(
         Agent.__init__(
             self,
             llm=llm,
+            embedding_model=embedding_model,
             quiet=quiet,
             tools=tools,
             messages=messages,
@@ -251,7 +253,7 @@ def validate_recipient(recipient_name: str):
             "instructions": self.instructions,
             "conversation_enabled": self._enable_conversation_tools
             or self._enable_conversational_content_responses,
-            "allow_exit": self._enable_exit_conversation,
+            "enable_exit_conversation": self._enable_exit_conversation,
         }
 
         tool_result = self._run_tool_loop(
@@ -320,9 +322,9 @@ def to_dict(self) -> Dict:
         data["persona"] = self._persona
         data["persona_summary"] = self._persona_summary
         data["instructions"] = self.instructions
-        data["allow_exit"] = self._enable_exit_conversation
         data["quiet"] = self._quiet
         data["enable_conversation_tools"] = self._enable_conversation_tools
+        data["enable_exit_conversation"] = self._enable_exit_conversation
         data[
             "enable_conversational_content_responses"
         ] = self._enable_conversational_content_responses
@@ -355,7 +357,7 @@ def from_dict(
             persona=data["persona"],
             persona_summary=data["persona_summary"],
             instructions=data["instructions"],
-            allow_exit=data["allow_exit"],
+            enable_exit_conversation=data["enable_exit_conversation"],
             quiet=data["quiet"],
             enable_conversation_tools=data["enable_conversation_tools"],
             enable_conversational_content_responses=data[
diff --git a/bondai/cli/cli.py b/bondai/cli/cli.py
index f99f7d4..06a95c5 100644
--- a/bondai/cli/cli.py
+++ b/bondai/cli/cli.py
@@ -105,7 +105,7 @@ def build_agents(llm: LLM) -> GroupConversation:
         tools=[AgentTool(task_execution_agent)],
         enable_conversation_tools=False,
         enable_conversational_content_responses=True,
-        allow_exit=False,
+        enable_exit_conversation=False,
         memory_manager=MemoryManager(
             core_memory_datasource=PersistentCoreMemoryDataSource(
                 file_path="./.memory/user_liason_core_memory.json",
diff --git a/website/docs/api-spec/create-agent.md b/website/docs/api-spec/create-agent.md
index 1483455..8d2a19e 100644
--- a/website/docs/api-spec/create-agent.md
+++ b/website/docs/api-spec/create-agent.md
@@ -12,7 +12,7 @@ This API will create a Conversational Agent.
 
 ```json
 {
-    "allow_exit":true,
+    "enable_exit_conversation":true,
     "enable_conversation_tools":false,
     "enable_conversational_content_responses":true,
     "id":"34c2262b-1a9b-4ace-9b74-54e892ea59a2",
diff --git a/website/docs/api-spec/get-agent.md b/website/docs/api-spec/get-agent.md
index 22ca1ed..d1a2a8d 100644
--- a/website/docs/api-spec/get-agent.md
+++ b/website/docs/api-spec/get-agent.md
@@ -12,7 +12,7 @@ This API returns the current state of an Agent.
 
 ```json
 {
-    "allow_exit":true,
+    "enable_exit_conversation":true,
     "enable_conversation_tools":false,
     "enable_conversational_content_responses":true,
     "id":"34c2262b-1a9b-4ace-9b74-54e892ea59a2",
diff --git a/website/docs/api-spec/list-agents.md b/website/docs/api-spec/list-agents.md
index bce7728..458d957 100644
--- a/website/docs/api-spec/list-agents.md
+++ b/website/docs/api-spec/list-agents.md
@@ -13,7 +13,7 @@ This API returns a list of all active Agents.
 ```json
 [
     {
-        "allow_exit":true,
+        "enable_exit_conversation":true,
         "enable_conversation_tools":false,
         "enable_conversational_content_responses":true,
         "id":"34c2262b-1a9b-4ace-9b74-54e892ea59a2",