Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go): Added generate-level middleware. #1949

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion go/ai/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ func (pm *programmableModel) Generate(ctx context.Context, r *registry.Registry,

func defineProgrammableModel(r *registry.Registry) *programmableModel {
pm := &programmableModel{r: r}
DefineModel(r, "default", "programmableModel", nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
supports := &ModelInfoSupports{
Tools: true,
Multiturn: true,
}
DefineModel(r, "", "programmableModel", &ModelInfo{Supports: supports}, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb)
})
return pm
Expand Down
1 change: 0 additions & 1 deletion go/ai/embedder.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
Expand Down
6 changes: 3 additions & 3 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ type ModelInfoSupports struct {

// A ModelRequest is a request to generate completions from a model.
type ModelRequest struct {
Config any `json:"config,omitempty"`
Context []any `json:"context,omitempty"`
Messages []*Message `json:"messages,omitempty"`
Config any `json:"config,omitempty"`
Context []*Document `json:"context,omitempty"`
Messages []*Message `json:"messages,omitempty"`
// Output describes the desired response format.
Output *ModelRequestOutput `json:"output,omitempty"`
ToolChoice ToolChoice `json:"toolChoice,omitempty"`
Expand Down
87 changes: 55 additions & 32 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,22 @@ type Model interface {
// Name returns the registry name of the model.
Name() string
// Generate applies the [Model] to the provided request, handling tool requests and streaming.
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error)
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error)
}

type modelActionDef core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]
// ModelFunc is a function that generates a model response.
type ModelFunc = core.StreamingFunc[*ModelRequest, *ModelResponse, *ModelResponseChunk]

// ModelMiddleware is middleware for model generate requests.
type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk]

type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]
// ModelAction is an action for model generation.
type ModelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]

type generateAction = core.Action[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk]

type modelActionDef core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]

// ModelStreamingCallback is the type for the streaming callback of a model.
type ModelStreamingCallback = func(context.Context, *ModelResponseChunk) error

Expand All @@ -44,7 +51,7 @@ type ToolConfig struct {

// DefineGenerateAction defines a utility generate action.
func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAction {
return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, map[string]any{},
return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, nil,
func(ctx context.Context, req *GenerateActionOptions, cb ModelStreamingCallback) (output *ModelResponse, err error) {
logger.FromContext(ctx).Debug("GenerateAction",
"input", fmt.Sprintf("%#v", req))
Expand All @@ -53,9 +60,10 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc
"output", fmt.Sprintf("%#v", output),
"err", err)
}()

return tracing.RunInNewSpan(ctx, r.TracingState(), "generate", "util", false, req,
func(ctx context.Context, input *GenerateActionOptions) (*ModelResponse, error) {
model := LookupModel(r, "default", req.Model)
model := LookupModel(r, "", req.Model)
if model == nil {
return nil, fmt.Errorf("model %q not found", req.Model)
}
Expand Down Expand Up @@ -95,43 +103,43 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc
ReturnToolRequests: req.ReturnToolRequests,
}

return model.Generate(ctx, r, modelReq, toolCfg, cb)
return model.Generate(ctx, r, modelReq, nil, toolCfg, cb)
})
}))
}

// DefineModel registers the given generate function as an action, and returns a
// [Model] that runs it.
// DefineModel registers the given generate function as an action, and returns a [Model] that runs it.
func DefineModel(
r *registry.Registry,
provider, name string,
metadata *ModelInfo,
generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error),
info *ModelInfo,
generate ModelFunc,
) Model {
metadataMap := map[string]any{}
if metadata == nil {
if info == nil {
// Always make sure there's at least minimal metadata.
metadata = &ModelInfo{
info = &ModelInfo{
Label: name,
Supports: &ModelInfoSupports{},
Versions: []string{},
}
}
if metadata.Label != "" {
metadataMap["label"] = metadata.Label
if info.Label != "" {
metadataMap["label"] = info.Label
}
supports := map[string]bool{
"media": metadata.Supports.Media,
"multiturn": metadata.Supports.Multiturn,
"systemRole": metadata.Supports.SystemRole,
"tools": metadata.Supports.Tools,
"media": info.Supports.Media,
"multiturn": info.Supports.Multiturn,
"systemRole": info.Supports.SystemRole,
"tools": info.Supports.Tools,
"toolChoice": info.Supports.ToolChoice,
}
metadataMap["supports"] = supports
metadataMap["versions"] = metadata.Versions
metadataMap["versions"] = info.Versions

generate = core.ChainMiddleware(ValidateSupport(name, info.Supports))(generate)

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{
"model": metadataMap,
}, generate))
return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{"model": metadataMap}, generate))
}

// IsDefinedModel reports whether a model is defined.
Expand All @@ -158,6 +166,7 @@ type generateParams struct {
SystemPrompt *Message
MaxTurns int
ReturnToolRequests bool
Middleware []ModelMiddleware
}

// GenerateOption configures params of the Generate call.
Expand Down Expand Up @@ -224,10 +233,13 @@ func WithConfig(config any) GenerateOption {
}
}

// WithContext adds provided context to ModelRequest.
func WithContext(c ...any) GenerateOption {
// WithContext adds provided documents to ModelRequest.
func WithContext(docs ...*Document) GenerateOption {
return func(req *generateParams) error {
req.Request.Context = append(req.Request.Context, c...)
if req.Request.Context != nil {
return errors.New("generate.WithContext: cannot set context more than once")
}
req.Request.Context = docs
return nil
}
}
Expand Down Expand Up @@ -320,6 +332,17 @@ func WithToolChoice(toolChoice ToolChoice) GenerateOption {
}
}

// WithMiddleware returns a GenerateOption that records the given model
// middleware on the generate parameters.
//
// The option may be applied at most once per call: if the parameters already
// hold a non-nil Middleware slice, the option returns an error instead of
// replacing it.
func WithMiddleware(middleware ...ModelMiddleware) GenerateOption {
	return func(params *generateParams) error {
		// Happy path first: nothing has been registered yet, so store the
		// caller's middleware as-is.
		if params.Middleware == nil {
			params.Middleware = middleware
			return nil
		}
		return errors.New("generate.WithMiddleware: cannot set Middleware more than once")
	}
}

// Generate runs a generate request for this model and returns a ModelResponse.
func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*ModelResponse, error) {
req := &generateParams{
Expand Down Expand Up @@ -368,7 +391,7 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
ReturnToolRequests: req.ReturnToolRequests,
}

return req.Model.Generate(ctx, r, req.Request, toolCfg, req.Stream)
return req.Model.Generate(ctx, r, req.Request, req.Middleware, toolCfg, req.Stream)
}

// validateModelVersion checks in the registry the action of the
Expand All @@ -386,7 +409,7 @@ func validateModelVersion(r *registry.Registry, v string, req *generateParams) (

// at the end, a Model is an action so type conversion is required
if a, ok := m.(*modelActionDef); ok {
if !(modelVersionSupported(v, (*modelAction)(a).Desc().Metadata)) {
if !(modelVersionSupported(v, (*ModelAction)(a).Desc().Metadata)) {
return false, fmt.Errorf("version %s not supported", v)
}
} else {
Expand Down Expand Up @@ -435,7 +458,7 @@ func GenerateData(ctx context.Context, r *registry.Registry, value any, opts ...
}

// Generate applies the [Action] to the provided request, handling tool requests and streaming.
func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) {
func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) {
if m == nil {
return nil, errors.New("Generate called on a nil Model; check that all models are defined")
}
Expand All @@ -447,8 +470,6 @@ func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req
}
}

// TODO: Add warnings if the model does not support certain configuration options.

if req.Tools != nil {
toolNames := make(map[string]bool)
for _, tool := range req.Tools {
Expand All @@ -463,9 +484,11 @@ func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req
return nil, err
}

handler := core.ChainMiddleware(mw...)((*ModelAction)(m).Run)

currentTurn := 0
for {
resp, err := (*modelAction)(m).Run(ctx, req, cb)
resp, err := handler(ctx, req, cb)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -509,7 +532,7 @@ func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req
}
}

func (i *modelActionDef) Name() string { return (*modelAction)(i).Name() }
func (i *modelActionDef) Name() string { return (*ModelAction)(i).Name() }

// cloneMessage creates a deep copy of the provided Message.
func cloneMessage(m *Message) *Message {
Expand Down
65 changes: 59 additions & 6 deletions go/ai/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func TestGenerate(t *testing.T) {
},
},
Config: GenerationCommonConfig{Temperature: 1},
Context: []any{[]any{string("Banana")}},
Context: []*Document{&Document{Content: []*Part{NewTextPart("Banana")}}},
Output: &ModelRequestOutput{
Format: "json",
Schema: map[string]any{
Expand Down Expand Up @@ -310,7 +310,7 @@ func TestGenerate(t *testing.T) {
Temperature: 1,
}),
WithHistory(NewUserTextMessage("banana"), NewModelTextMessage("yes, banana")),
WithContext([]any{"Banana"}),
WithContext(&Document{Content: []*Part{NewTextPart("Banana")}}),
WithOutputSchema(&GameCharacter{}),
WithTools(gablorkenTool),
WithStreaming(func(ctx context.Context, grc *ModelResponseChunk) error {
Expand Down Expand Up @@ -346,7 +346,13 @@ func TestGenerate(t *testing.T) {
},
)

interruptModel := DefineModel(r, "test", "interrupt", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
interruptModel := DefineModel(r, "test", "interrupt", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
Expand Down Expand Up @@ -399,7 +405,13 @@ func TestGenerate(t *testing.T) {

t.Run("handles multiple parallel tool calls", func(t *testing.T) {
roundCount := 0
parallelModel := DefineModel(r, "test", "parallel", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
parallelModel := DefineModel(r, "test", "parallel", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
Expand Down Expand Up @@ -458,7 +470,13 @@ func TestGenerate(t *testing.T) {

t.Run("handles multiple rounds of tool calls", func(t *testing.T) {
roundCount := 0
multiRoundModel := DefineModel(r, "test", "multiround", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
multiRoundModel := DefineModel(r, "test", "multiround", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
Expand Down Expand Up @@ -520,7 +538,13 @@ func TestGenerate(t *testing.T) {
})

t.Run("exceeds maximum turns", func(t *testing.T) {
infiniteModel := DefineModel(r, "test", "infinite", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
infiniteModel := DefineModel(r, "test", "infinite", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
Expand Down Expand Up @@ -550,6 +574,35 @@ func TestGenerate(t *testing.T) {
t.Errorf("unexpected error message: %v", err)
}
})

t.Run("applies middleware", func(t *testing.T) {
middlewareCalled := false
testMiddleware := func(next ModelFunc) ModelFunc {
return func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
middlewareCalled = true
req.Messages = append(req.Messages, NewUserTextMessage("middleware was here"))
return next(ctx, req, cb)
}
}

res, err := Generate(context.Background(), r,
WithModel(echoModel),
WithTextPrompt("test middleware"),
WithMiddleware(testMiddleware),
)
if err != nil {
t.Fatal(err)
}

if !middlewareCalled {
t.Error("middleware was not called")
}

expectedText := "test middlewaremiddleware was here"
if res.Text() != expectedText {
t.Errorf("got text %q, want %q", res.Text(), expectedText)
}
})
}

func TestModelVersion(t *testing.T) {
Expand Down
Loading
Loading