Skip to content

Commit

Permalink
Refactor structured output chain
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Dec 16, 2023
1 parent d7d851f commit 2bf4a97
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 35 deletions.
7 changes: 7 additions & 0 deletions chain/chat_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ type ChatModelOptions struct {
// CallbackOptions contains options for the chain callbacks.
*schema.CallbackOptions

// ForceFunctionCall forced the model to call the first function
ForceFunctionCall bool

// OutputKey is the key to access the output value containing the ChatModel response summary.
OutputKey string
}
Expand Down Expand Up @@ -73,7 +76,11 @@ func (c *ChatModel) Call(ctx context.Context, inputs schema.ChainValues, optFns
}

result, err := model.GeneratePrompt(ctx, c.chatModel, pv, func(o *model.Options) {
o.Callbacks = opts.CallbackManger.GetInheritableCallbacks()
o.ParentRunID = opts.CallbackManger.RunID()
o.Stop = opts.Stop
o.Functions = c.functions
o.ForceFunctionCall = c.opts.ForceFunctionCall
})
if err != nil {
return nil, err
Expand Down
32 changes: 23 additions & 9 deletions chain/structured_output.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ import (
// Compile time check to ensure StructuredOutput satisfies the Chain interface.
var _ schema.Chain = (*StructuredOutput)(nil)

const defaultStructuredOutputTemplate = `Extract and save the relevant entities mentioned in the following passage.
Passage:
{{.input}}`

// OutputCandidate represents a candidate for structured output containing a name,
// description, and data of any struct type.
type OutputCandidate struct {
Expand All @@ -28,9 +33,18 @@ type OutputCandidate struct {
// StructuredOutputOptions contains options for configuring the StructuredOutput chain.
type StructuredOutputOptions struct {
*schema.CallbackOptions
Prompt prompt.ChatTemplate
OutputKey string
}

var DefaultStructuredOutputTemplate = StructuredOutputOptions{
CallbackOptions: &schema.CallbackOptions{},
Prompt: prompt.NewChatTemplate([]prompt.MessageTemplate{
prompt.NewHumanMessageTemplate(defaultStructuredOutputTemplate),
}),
OutputKey: "output",
}

// StructuredOutput is a chain that generates structured output using a ChatModel chain and candidate values.
type StructuredOutput struct {
chatModelChain *ChatModel
Expand All @@ -39,13 +53,9 @@ type StructuredOutput struct {
}

// NewStructuredOutput creates a new StructuredOutput chain with the given ChatModel, prompt, and candidates.
func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate, candidates []OutputCandidate, optFns ...func(o *StructuredOutputOptions)) (*StructuredOutput, error) {
opts := StructuredOutputOptions{
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
OutputKey: "output",
}
func NewStructuredOutput(chatModel schema.ChatModel, candidates []OutputCandidate, optFns ...func(o *StructuredOutputOptions)) (*StructuredOutput, error) {
opts := DefaultStructuredOutputTemplate
opts.Verbose = golc.Verbose

for _, fn := range optFns {
fn(&opts)
Expand Down Expand Up @@ -75,7 +85,10 @@ func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate,
})
}

chatModelChain, err := NewChatModelWithFunctions(chatModel, prompt, functions)
chatModelChain, err := NewChatModelWithFunctions(chatModel, opts.Prompt, functions, func(o *ChatModelOptions) {
o.CallbackOptions = opts.CallbackOptions
o.ForceFunctionCall = true
})
if err != nil {
return nil, err
}
Expand All @@ -101,12 +114,13 @@ func (c *StructuredOutput) Call(ctx context.Context, inputs schema.ChainValues,
output, err := golc.Call(ctx, c.chatModelChain, inputs, func(sco *golc.CallOptions) {
sco.Callbacks = opts.CallbackManger.GetInheritableCallbacks()
sco.ParentRunID = opts.CallbackManger.RunID()
sco.Stop = opts.Stop
})
if err != nil {
return nil, err
}

aiMsg, ok := output["message"].(*schema.AIChatMessage)
aiMsg, ok := output[c.chatModelChain.OutputKeys()[0]].(*schema.AIChatMessage)
if !ok {
return nil, errors.New("unexpected output: message is not a ai chat message")
}
Expand Down
8 changes: 1 addition & 7 deletions chain/structured_output_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/model/chatmodel"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
"github.com/stretchr/testify/require"
)
Expand All @@ -29,11 +28,6 @@ func TestStructuredOutput(t *testing.T) {
}, nil
})

// Create a dummy prompt template for testing
promptTemplate := prompt.NewChatTemplate([]prompt.MessageTemplate{
prompt.NewHumanMessageTemplate("{{.input}}"),
})

// Create a dummy output candidate
type person struct {
Name string `json:"name" description:"The person's name"`
Expand All @@ -42,7 +36,7 @@ func TestStructuredOutput(t *testing.T) {
}

// Create a new StructuredOutput chain
structuredOutputChain, err := NewStructuredOutput(chatModel, promptTemplate, []OutputCandidate{
structuredOutputChain, err := NewStructuredOutput(chatModel, []OutputCandidate{
{
Name: "Person",
Description: "Identifying information about a person",
Expand Down
13 changes: 10 additions & 3 deletions chain/tagging.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,22 @@ type Tagging struct {
// NewTagging creates a new Tagging chain with the provided chat model, structured output data, and optional options.
// It returns a Tagging chain or an error if the creation fails.
func NewTagging(chatModel schema.ChatModel, data any, optFns ...func(o *StructuredOutputOptions)) (*Tagging, error) {
pt := prompt.NewChatTemplate([]prompt.MessageTemplate{
opts := DefaultStructuredOutputTemplate
opts.Prompt = prompt.NewChatTemplate([]prompt.MessageTemplate{
prompt.NewHumanMessageTemplate(defaultTaggingTemplate),
})

so, err := NewStructuredOutput(chatModel, pt, []OutputCandidate{{
for _, fn := range optFns {
fn(&opts)
}

so, err := NewStructuredOutput(chatModel, []OutputCandidate{{
Name: "InformationExtraction",
Description: "Extracts the relevant information from the passage.",
Data: data,
}}, optFns...)
}}, func(o *StructuredOutputOptions) {
*o = opts
})
if err != nil {
return nil, err
}
Expand Down
12 changes: 4 additions & 8 deletions examples/structured_output_chain/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/model/chatmodel"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

Expand All @@ -20,6 +19,8 @@ type Person struct {
}

func main() {
golc.Verbose = true

chatModel, err := chatmodel.NewOpenAI(os.Getenv("OPENAI_API_KEY"), func(o *chatmodel.OpenAIOptions) {
o.ModelName = "gpt-4"
o.Temperature = 0
Expand All @@ -28,15 +29,10 @@ func main() {
log.Fatal(err)
}

pt := prompt.NewChatTemplate([]prompt.MessageTemplate{
prompt.NewSystemMessageTemplate("You are a world class algorithm for extracting information in structured formats."),
prompt.NewHumanMessageTemplate("Use the given format to extract information from the following input:\n{{.input}}\nTips: Make sure to answer in the correct format"),
})

structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, []chain.OutputCandidate{
structuredOutputChain, err := chain.NewStructuredOutput(chatModel, []chain.OutputCandidate{
{
Name: "Person",
Description: "Identifying information about a person",
Description: "Information about a person",
Data: &Person{},
},
})
Expand Down
12 changes: 11 additions & 1 deletion model/chatmodel/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,16 @@ func (cm *OpenAI) Generate(ctx context.Context, messages schema.ChatMessages, op
FrequencyPenalty: cm.opts.PresencePenalty,
Messages: openAIMessages,
Tools: tools,
ToolChoice: "auto",
Stop: opts.Stop,
}

if opts.ForceFunctionCall && len(opts.Functions) == 1 {
request.ToolChoice = openai.ToolChoice{Type: openai.ToolTypeFunction, Function: openai.ToolFunction{
Name: opts.Functions[0].Name,
}}
}

choices := []openai.ChatCompletionChoice{}
tokenUsage := make(map[string]int)

Expand Down Expand Up @@ -200,8 +207,11 @@ func (cm *OpenAI) Generate(ctx context.Context, messages schema.ChatMessages, op

role = res.Choices[0].Delta.Role
tokens = append(tokens, res.Choices[0].Delta.Content)
functionCall = res.Choices[0].Delta.FunctionCall
finishReason = res.Choices[0].FinishReason

if len(res.Choices[0].Delta.ToolCalls) > 0 {
functionCall = &res.Choices[0].Delta.ToolCalls[0].Function
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions model/chatmodel/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestOpenAI_Generate(t *testing.T) {
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there"},
},
ToolChoice: "auto",
}
mockResponse := openai.ChatCompletionResponse{
Choices: []openai.ChatCompletionChoice{
Expand Down Expand Up @@ -74,6 +75,7 @@ func TestOpenAI_Generate(t *testing.T) {
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there"},
},
ToolChoice: "auto",
}
mockError := errors.New("generation error")
mockClient.createChatCompletionFn = func(ctx context.Context, request openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
Expand Down
10 changes: 6 additions & 4 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
)

type Options struct {
Stop []string
Callbacks []schema.Callback
ParentRunID string
Functions []schema.FunctionDefinition
Stop []string
Callbacks []schema.Callback
ParentRunID string
Functions []schema.FunctionDefinition
ForceFunctionCall bool
}

func GeneratePrompt(ctx context.Context, model schema.Model, promptValue schema.PromptValue, optFns ...func(o *Options)) (*schema.ModelResult, error) {
Expand Down Expand Up @@ -97,6 +98,7 @@ func ChatModelGenerate(ctx context.Context, model schema.ChatModel, messages sch
o.CallbackManger = rm
o.Stop = opts.Stop
o.Functions = opts.Functions
o.ForceFunctionCall = opts.ForceFunctionCall
})
if err != nil {
if cbErr := rm.OnModelError(ctx, &schema.ModelErrorManagerInput{
Expand Down
7 changes: 4 additions & 3 deletions schema/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ type FunctionDefinition struct {
}

type GenerateOptions struct {
CallbackManger CallbackManagerForModelRun
Stop []string
Functions []FunctionDefinition
CallbackManger CallbackManagerForModelRun
Stop []string
Functions []FunctionDefinition
ForceFunctionCall bool
}

// LLM is the interface for language models.
Expand Down

0 comments on commit 2bf4a97

Please sign in to comment.