Skip to content

Commit

Permalink
PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dysrama committed Feb 27, 2025
1 parent e48577e commit d15694f
Show file tree
Hide file tree
Showing 6 changed files with 391 additions and 363 deletions.
24 changes: 20 additions & 4 deletions go/genkit/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ import (
"strconv"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit/session"
)

type Chat struct {
Genkit *Genkit `json:"genkit,omitempty"`
Model ai.Model `json:"model,omitempty"` // The model to query
ThreadName string `json:"threadName,omitempty"` // The chats threadname
Session *Session `json:"session,omitempty"` // The chats session
Session *session.Session `json:"session,omitempty"` // The chats session
SystemText string `json:"systemtext,omitempty"` // Message sent to the model as system instructions
Prompt *ai.Prompt `json:"prompt,omitempty"` // Optional prompt
Input any `json:"input,omitempty"` // Optional input fields for the chat. This should be a struct, a pointer to a struct that matches the input schema, or a string.
Expand Down Expand Up @@ -62,7 +63,7 @@ func NewChat(ctx context.Context, g *Genkit, opts ...ChatOption) (chat *Chat, er
}

if chat.Session == nil {
s, err := NewSession(ctx)
s, err := session.New(ctx)
if err != nil {
return nil, err
}
Expand All @@ -83,7 +84,7 @@ func NewChat(ctx context.Context, g *Genkit, opts ...ChatOption) (chat *Chat, er
// included in the conversation before the history.
func (c *Chat) Send(ctx context.Context, msg any) (resp *ai.ModelResponse, err error) {
// Load history
data, err := c.Session.Store.Get(c.Session.ID)
data, err := c.Session.GetData()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -151,6 +152,21 @@ func (c *Chat) Send(ctx context.Context, msg any) (resp *ai.ModelResponse, err e
return resp, nil
}

// SendText sends a text message to the chat, generating a response from the AI model and return the text.
// It retrieves the chat history from the session store, adds the new message
// to the history, and sends the entire conversation to the AI model for
// generating a response. If a system message is set for the chat, it is
// included in the conversation before the history.
func (c *Chat) SendText(ctx context.Context, msgText string) (string, error) {
msg := ai.NewUserTextMessage(msgText)
resp, err := c.Send(ctx, msg)
if err != nil {
return "", err
}

return resp.Text(), nil
}

// UpdateMessages updates the messages for the chat.
func (c *Chat) UpdateMessages(thread string, msgs []*ai.Message) error {
c.Request.Messages = msgs
Expand All @@ -169,7 +185,7 @@ func WithModel(model ai.Model) ChatOption {
}

// WithSession sets a session for the chat.
func WithSession(session *Session) ChatOption {
func WithSession(session *session.Session) ChatOption {
return func(c *Chat) error {
if c.Session != nil {
return errors.New("genkit.WithSession: cannot set session more than once")
Expand Down
47 changes: 33 additions & 14 deletions go/genkit/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ import (
"testing"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit/session"
"github.com/firebase/genkit/go/internal/base"
"github.com/invopop/jsonschema"
)

type HelloPromptInput struct {
UserName string
Name string
}

var chatGenkit, _ = New(nil)
Expand All @@ -41,12 +42,12 @@ func getNameTool(g *Genkit) *ai.ToolDef[struct{ Name string }, string] {
Name string
}) (string, error) {
// Set name in state
session, err := SessionFromContext(ctx)
session, err := session.FromContext(ctx)
if err != nil {
return "", err
}

err = session.UpdateState(input.Name)
err = session.UpdateState(input)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -133,7 +134,7 @@ func getChatPrompt(g *Genkit) *ai.Prompt {
}
prompt := fmt.Sprintf(
"Say hello to %s",
params.UserName)
params.Name)
return &ai.ModelRequest{Messages: []*ai.Message{
{Content: []*ai.Part{ai.NewTextPart(prompt)}},
}}, nil
Expand Down Expand Up @@ -211,7 +212,7 @@ func TestChatWithStreaming(t *testing.T) {
func TestChatWithOptions(t *testing.T) {
ctx := context.Background()

session, err := NewSession(ctx)
session, err := session.New(ctx)
if err != nil {
t.Fatal(err.Error())
}
Expand All @@ -233,7 +234,7 @@ func TestChatWithOptions(t *testing.T) {
t.Fatal(err.Error())
}

if chat.Session == nil || chat.Session.ID != session.ID {
if chat.Session == nil || chat.Session.GetID() != session.GetID() {
t.Errorf("session is not set")
}

Expand Down Expand Up @@ -267,7 +268,7 @@ func TestChatWithOptions(t *testing.T) {
func TestChatWithOptionsErrorHandling(t *testing.T) {
ctx := context.Background()

session, err := NewSession(ctx)
session, err := session.New(ctx)
if err != nil {
t.Fatal(err.Error())
}
Expand Down Expand Up @@ -436,7 +437,7 @@ func TestGetChatMessages(t *testing.T) {
func TestMultiChatSession(t *testing.T) {
ctx := context.Background()

session, err := NewSession(ctx)
session, err := session.New(ctx)
if err != nil {
t.Fatal(err.Error())
}
Expand Down Expand Up @@ -485,16 +486,20 @@ func TestMultiChatSession(t *testing.T) {
t.Errorf("got %q want %q", resp.Text(), want)
}

if len(session.SessionData.Threads) != 2 {
data, err := session.GetData()
if err != nil {
t.Fatal(err.Error())
}
if len(data.Threads) != 2 {
t.Errorf("session should have 2 threads")
}
}

func TestStateUpdate(t *testing.T) {
ctx := context.Background()

session, err := NewSession(ctx,
WithStateType("no name"),
session, err := session.New(ctx,
session.WithStateType(HelloPromptInput{}),
)
if err != nil {
t.Fatal(err.Error())
Expand Down Expand Up @@ -522,7 +527,11 @@ func TestStateUpdate(t *testing.T) {
t.Errorf("got %q want %q", resp.Text(), want)
}

if session.SessionData.State["state"] != "Earl" {
data, err := session.GetData()
if err != nil {
t.Fatal(err.Error())
}
if data.State["Name"] != "Earl" {
t.Error("session state not set")
}
}
Expand All @@ -536,7 +545,7 @@ func TestChatWithPrompt(t *testing.T) {
chatGenkit,
WithModel(chatModel),
WithPrompt(chatPrompt),
WithInput(HelloPromptInput{UserName: "Earl"}),
WithInput(HelloPromptInput{Name: "Earl"}),
)
if err != nil {
t.Fatal(err.Error())
Expand All @@ -552,8 +561,18 @@ func TestChatWithPrompt(t *testing.T) {
t.Errorf("got %q want %q", resp.Text(), want)
}

// Send text instead of messages
text, err := chat.SendText(ctx, "Send prompt")
if err != nil {
t.Fatal(err.Error())
}

if !strings.Contains(text, want) {
t.Errorf("got %q want %q", text, want)
}

// Rendered prompt to chat
mr, err := chatPrompt.Render(ctx, HelloPromptInput{UserName: "someone else"})
mr, err := chatPrompt.Render(ctx, HelloPromptInput{Name: "someone else"})
if err != nil {
t.Fatal(err.Error())
}
Expand Down
Loading

0 comments on commit d15694f

Please sign in to comment.