Merge pull request #58 from oramasearch/feat/party-planner
feat: adds party planner
micheleriva authored Jan 23, 2025
2 parents 93750b2 + cab5d65 commit a27bd71
Showing 9 changed files with 228 additions and 33 deletions.
10 changes: 10 additions & 0 deletions src/ai_server/service.proto
@@ -7,6 +7,7 @@ service LLMService {
rpc GetEmbedding (EmbeddingRequest) returns (EmbeddingResponse);
rpc Chat (ChatRequest) returns (ChatResponse);
rpc ChatStream (ChatRequest) returns (stream ChatStreamResponse);
rpc PlannedAnswer (PlannedAnswerRequest) returns (stream PlannedAnswerResponse);
}

enum OramaModel {
@@ -63,6 +64,15 @@ message Embedding {
repeated float embeddings = 1; // Array of float values
}

// Request message for a planned answer
message PlannedAnswerRequest {
string input = 1; // The user input
}

message PlannedAnswerResponse {
string plan = 1;
}

// Request message for LLM calls
message ChatRequest {
LLMType model = 1; // Which LLM to use
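
For context, a minimal client-side sketch of the new RPC (the address, port, and input string below are illustrative, and the generated modules are assumed importable as service_pb2 / service_pb2_grpc, as under src/ai_server). Although the .proto declares a streamed response, the generated stubs in this commit register PlannedAnswer as unary-unary, so the sketch reads a single response:

    import grpc

    import service_pb2
    import service_pb2_grpc

    # Hypothetical address; point this at wherever the AI server is listening.
    channel = grpc.insecure_channel("localhost:50051")
    stub = service_pb2_grpc.LLMServiceStub(channel)

    # PlannedAnswer takes the raw user input and returns a JSON action plan.
    response = stub.PlannedAnswer(service_pb2.PlannedAnswerRequest(input="How do I paginate results?"))
    print(response.plan)  # a JSON string shaped like {"actions": [...]}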
46 changes: 25 additions & 21 deletions src/ai_server/service_pb2.py

Some generated files are not rendered by default.

47 changes: 47 additions & 0 deletions src/ai_server/service_pb2_grpc.py
@@ -59,6 +59,12 @@ def __init__(self, channel):
response_deserializer=service__pb2.ChatStreamResponse.FromString,
_registered_method=True,
)
self.PlannedAnswer = channel.unary_unary(
"/orama_ai_service.LLMService/PlannedAnswer",
request_serializer=service__pb2.PlannedAnswerRequest.SerializeToString,
response_deserializer=service__pb2.PlannedAnswerResponse.FromString,
_registered_method=True,
)


class LLMServiceServicer(object):
@@ -88,6 +94,12 @@ def ChatStream(self, request, context):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def PlannedAnswer(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")


def add_LLMServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -111,6 +123,11 @@ def add_LLMServiceServicer_to_server(servicer, server):
request_deserializer=service__pb2.ChatRequest.FromString,
response_serializer=service__pb2.ChatStreamResponse.SerializeToString,
),
"PlannedAnswer": grpc.unary_unary_rpc_method_handler(
servicer.PlannedAnswer,
request_deserializer=service__pb2.PlannedAnswerRequest.FromString,
response_serializer=service__pb2.PlannedAnswerResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler("orama_ai_service.LLMService", rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
@@ -240,3 +257,33 @@ def ChatStream(
metadata,
_registered_method=True,
)

@staticmethod
def PlannedAnswer(
request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.unary_unary(
request,
target,
"/orama_ai_service.LLMService/PlannedAnswer",
service__pb2.PlannedAnswerRequest.SerializeToString,
service__pb2.PlannedAnswerResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True,
)
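
The add_LLMServiceServicer_to_server helper above is what ties a servicer into a running server. A minimal sketch of that wiring (port and worker count are illustrative; the real serve() in src/grpc/server.py also enables reflection and reads its settings from config):

    import grpc
    from concurrent.futures import ThreadPoolExecutor

    import service_pb2_grpc

    def run(servicer):
        # servicer: an instance of the LLMService servicer defined in src/grpc/server.py
        server = grpc.server(ThreadPoolExecutor(max_workers=10))
        service_pb2_grpc.add_LLMServiceServicer_to_server(servicer, server)
        server.add_insecure_port("[::]:50051")
        server.start()
        server.wait_for_termination()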
26 changes: 22 additions & 4 deletions src/ai_server/src/grpc/server.py
@@ -1,5 +1,6 @@
import grpc
import logging
from json_repair import repair_json
from grpc_reflection.v1alpha import reflection
from concurrent.futures import ThreadPoolExecutor

@@ -15,13 +16,16 @@
ChatStreamResponse,
HealthCheckResponse,
LLMType,
PlannedAnswerResponse,
)
from src.prompts.party_planner import PartyPlannerActions


class LLMService(service_pb2_grpc.LLMServiceServicer):
def __init__(self, embeddings_service, models_manager):
self.embeddings_service = embeddings_service
self.models_manager = models_manager
self.party_planner_actions = PartyPlannerActions()

def CheckHealth(self, request, context):
return HealthCheckResponse(status="OK")
@@ -45,8 +49,6 @@ def GetEmbedding(self, request, context):
def Chat(self, request, context):
try:
model_name = LLMType.Name(request.model)
logging.info(f"Received Chat request with model: {model_name}, prompt: {request.prompt}")

history = (
[
{"role": ProtoRole.Name(message.role).lower(), "content": message.content}
@@ -67,8 +69,6 @@ def Chat(self, request, context):
def ChatStream(self, request, context):
try:
model_name = LLMType.Name(request.model)
logging.info(f"Received ChatStream request with model: {model_name}, prompt: {request.prompt}")

history = (
[
{"role": ProtoRole.Name(message.role).lower(), "content": message.content}
@@ -88,6 +88,24 @@ def ChatStream(self, request, context):
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"Error in chat stream: {str(e)}")

def PlannedAnswer(self, request, context):
try:
model_name = "party_planner"
history = []
response = self.models_manager.chat(
model_id=model_name.lower(),
history=history,
prompt=request.input,
context=self.party_planner_actions.get_actions(),
)
return PlannedAnswerResponse(plan=repair_json(response))

except Exception as e:
logging.error(f"Error in PlannedAnswer: {e}", exc_info=True)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"Error in planned answer stream: {str(e)}")
return PlannedAnswerResponse()


def serve(config, embeddings_service, models_manager):
logger = logging.getLogger(__name__)
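
Note how PlannedAnswer routes the model output through json_repair's repair_json before building the response: LLMs occasionally emit truncated or otherwise malformed JSON, and repair_json coerces such output into a syntactically valid JSON string. An illustrative sketch (the exact repaired result depends on the library's heuristics):

    from json_repair import repair_json

    # A typical failure mode: the model stops generating mid-object.
    broken = '{"actions": [{"step": "OPTIMIZE_QUERY", "description": "Rewrite the query"'
    fixed = repair_json(broken)  # returns a syntactically valid JSON string
    print(fixed)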
47 changes: 42 additions & 5 deletions src/ai_server/src/prompts/main.py
@@ -11,29 +11,31 @@
"google_query_translator:user",
"answer:system",
"answer:user",
"party_planner:system",
"party_planner:user",
]

PROMPT_TEMPLATES: Dict[TemplateKey, PromptTemplate] = {
# ------------------------------
# Vision eCommerce model
# ------------------------------
"vision_ecommerce:system": "You are a product description assistant.",
"vision_ecommerce:user": lambda prompt, context: f"Describe the product shown in the image. Include details about its mood, colors, and potential use cases.\n\nImage: {prompt}",
"vision_ecommerce:user": lambda prompt, _context: f"Describe the product shown in the image. Include details about its mood, colors, and potential use cases.\n\nImage: {prompt}",
# ------------------------------
# Vision generic model
# ------------------------------
"vision_generic:system": "You are an image analysis assistant.",
"vision_generic:user": lambda prompt, context: f"Provide a detailed analysis of what is shown in this image, including key elements and their relationships.\n\nImage: {prompt}",
"vision_generic:user": lambda prompt, _context: f"Provide a detailed analysis of what is shown in this image, including key elements and their relationships.\n\nImage: {prompt}",
# ------------------------------
# Vision technical documentation model
# ------------------------------
"vision_tech_documentation:system": "You are a technical documentation analyzer.",
"vision_tech_documentation:user": lambda prompt, context: f"Analyze this technical documentation image, focusing on its key components and technical details.\n\nImage: {prompt}",
"vision_tech_documentation:user": lambda prompt, _context: f"Analyze this technical documentation image, focusing on its key components and technical details.\n\nImage: {prompt}",
# ------------------------------
# Vision code model
# ------------------------------
"vision_code:system": "You are a code analysis assistant.",
"vision_code:user": lambda prompt, context: f"Analyze the provided code block, explaining its functionality, implementation details, and intended purpose.\n\nCode: {prompt}",
"vision_code:user": lambda prompt, _context: f"Analyze the provided code block, explaining its functionality, implementation details, and intended purpose.\n\nCode: {prompt}",
# ------------------------------
# Google Query Translator model
# ------------------------------
@@ -43,7 +45,7 @@
'Your reply must be in the following format: {"query": "<translated_query>"}. As you can see, the translated query must be a JSON object with a single key, \'query\', whose value is the translated query. '
"Always reply with the most relevant and concise query possible in a valid JSON format, and nothing more."
),
"google_query_translator:user": lambda query, context: f"### Query\n{query}\n\n### Translated Query\n",
"google_query_translator:user": lambda query, _context: f"### Query\n{query}\n\n### Translated Query\n",
# ------------------------------
# Answer model
# ------------------------------
@@ -72,4 +74,39 @@
"""
),
"answer:user": lambda context, question: f"### Context\n{context}\n\n### Question\n{question}\n\n",
# ------------------------------
# Party planner
# ------------------------------
"party_planner:system": (
"""
You are an AI action planner. Given a set of allowed actions and user input, output a minimal sequence of actions to achieve the desired outcome.
### Input Format
actions: Array of allowed action names and their descriptions
input: User's desired outcome
### Output Format
JSON object with array of ordered steps:
{
"actions": [
{
"step": "action_name",
"description": "Specific description of how to apply this action"
}
]
}
### Constraints
- Only use actions from the provided allowed set
- Minimize number of steps
- Each step must move toward the goal
- Return error object if goal is impossible with given actions
### Example
Input: "Can you give me an example of how my data has to look when using the standard getExpandedRowModel() function?"
Actions: ["OPTIMIZE_QUERY", "PERFORM_ORAMA_SEARCH", "CREATE_CODE", "SUMMARIZE_FINDINGS", "GIVE_REPLY"]
Output: {"actions":[{ "step": "OPTIMIZE_QUERY", "description": "Optimize query into a more search-friendly query" }, { "step": "PERFORM_SEARCH", "description": "Use optimized query to perform search in the index" }, { "step": "CREATE_CODE", "description": "Craft code examples about using getExpandedRowModel() function" }, { "step": "SUMMARIZE_FINDINGS", "description": "Summarize the findings from the research and code generation" }]}
"""
),
"party_planner:user": lambda input, actions: f"### Input\n{input}\n\n### Actions\n{actions}",
}
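
To see how the new templates are consumed, here is a hedged sketch of rendering the party_planner pair (the import path and the shape of the actions payload are assumptions; the real PartyPlannerActions.get_actions() lives in src/prompts/party_planner.py, which is not rendered in this diff):

    from src.prompts.main import PROMPT_TEMPLATES

    # Assumed shape for the allowed-actions payload.
    actions = '[{"name": "PERFORM_ORAMA_SEARCH", "description": "Search the Orama index"}]'

    system_prompt = PROMPT_TEMPLATES["party_planner:system"]
    user_prompt = PROMPT_TEMPLATES["party_planner:user"]("How do I sort results?", actions)
    # user_prompt == "### Input\nHow do I sort results?\n\n### Actions\n[ ... ]"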
