Merge pull request #11 from tkawachi/streaming
Add non-streaming flag
mergify[bot] authored Mar 5, 2023
2 parents 88b1c34 + 0d6dd1b commit b4b23db
Showing 1 changed file with 67 additions and 20 deletions.
87 changes: 67 additions & 20 deletions aichat.go
@@ -3,7 +3,9 @@ package main
 import (
     "bufio"
     "context"
+    "errors"
     "fmt"
+    "io"
     "log"
     "os"
     "strings"
@@ -13,15 +15,53 @@ import (
 )

 type chatOptions struct {
-    temperature float32
-    maxTokens   int
+    temperature  float32
+    maxTokens    int
+    nonStreaming bool
 }

 type AIChat struct {
     client  *gogpt.Client
     options chatOptions
 }

+// streamCompletion prints out the chat completion in streaming mode.
+func streamCompletion(client *gogpt.Client, request gogpt.ChatCompletionRequest, out io.Writer) error {
+    stream, err := client.CreateChatCompletionStream(context.Background(), request)
+    if err != nil {
+        return err
+    }
+    defer stream.Close()
+    for {
+        response, err := stream.Recv()
+        if errors.Is(err, io.EOF) {
+            fmt.Println()
+            break
+        }
+        if err != nil {
+            return fmt.Errorf("stream recv: %w", err)
+        }
+        _, err = fmt.Fprint(out, response.Choices[0].Delta.Content)
+        if err != nil {
+            return err
+        }
+    }
+    return nil
+}
+
+// nonStreamCompletion prints out the chat completion in non-streaming mode.
+func nonStreamCompletion(client *gogpt.Client, request gogpt.ChatCompletionRequest, out io.Writer) error {
+    response, err := client.CreateChatCompletion(context.Background(), request)
+    if err != nil {
+        return err
+    }
+    if len(response.Choices) == 0 {
+        return fmt.Errorf("no choices returned")
+    }
+    _, err = fmt.Fprint(out, response.Choices[0].Message.Content+"\n")
+    return err
+}
+
 func (aiChat *AIChat) stdChatLoop() error {
     messages := []gogpt.ChatCompletionMessage{}
     scanner := bufio.NewScanner(os.Stdin)
@@ -36,19 +76,22 @@ func (aiChat *AIChat) stdChatLoop() error {
             Role:    gogpt.ChatMessageRoleUser,
             Content: input,
         })
-        response, err := aiChat.client.CreateChatCompletion(context.Background(), gogpt.ChatCompletionRequest{
+        fmt.Print("assistant: ")
+        request := gogpt.ChatCompletionRequest{
             Model:       gogpt.GPT3Dot5Turbo,
             Messages:    messages,
             Temperature: aiChat.options.temperature,
             MaxTokens:   aiChat.options.maxTokens,
-        })
+        }
+        var err error
+        if aiChat.options.nonStreaming {
+            err = nonStreamCompletion(aiChat.client, request, os.Stdout)
+        } else {
+            err = streamCompletion(aiChat.client, request, os.Stdout)
+        }
         if err != nil {
             return err
         }
-        if len(response.Choices) == 0 {
-            return fmt.Errorf("no choices")
-        }
-        fmt.Println("assistant: " + response.Choices[0].Message.Content)
         fmt.Print("user: ")
     }
     return scanner.Err()
@@ -77,10 +120,12 @@ func main() {
     var maxTokens = 500
     var verbose = false
     var listPrompts = false
-    getopt.Flag(&temperature, 't', "temperature", "temperature")
-    getopt.Flag(&maxTokens, 'm', "max-tokens", "max tokens")
-    getopt.Flag(&verbose, 'v', "verbose", "verbose")
-    getopt.Flag(&listPrompts, 'l', "list-prompts", "list prompts")
+    var nonStreaming = false
+    getopt.FlagLong(&temperature, "temperature", 't', "temperature")
+    getopt.FlagLong(&maxTokens, "max-tokens", 'm', "max tokens")
+    getopt.FlagLong(&verbose, "verbose", 'v', "verbose output")
+    getopt.FlagLong(&listPrompts, "list-prompts", 'l', "list prompts")
+    getopt.FlagLong(&nonStreaming, "non-streaming", 0, "non streaming mode")
     getopt.Parse()

     if listPrompts {
@@ -95,8 +140,9 @@ func main() {
         log.Fatal(err)
     }
     options := chatOptions{
-        temperature: temperature,
-        maxTokens:   maxTokens,
+        temperature:  temperature,
+        maxTokens:    maxTokens,
+        nonStreaming: nonStreaming,
     }
     if verbose {
         log.Printf("options: %+v", options)
@@ -130,19 +176,20 @@ func main() {
         if verbose {
             log.Printf("messages: %+v", messages)
         }
-        response, err := aiChat.client.CreateChatCompletion(context.Background(), gogpt.ChatCompletionRequest{
+        request := gogpt.ChatCompletionRequest{
             Model:       gogpt.GPT3Dot5Turbo,
             Messages:    messages,
             Temperature: firstNonZeroFloat32(prompt.Temperature, aiChat.options.temperature),
             MaxTokens:   firstNonZeroInt(prompt.MaxTokens, aiChat.options.maxTokens),
-        })
+        }
+        if aiChat.options.nonStreaming {
+            err = nonStreamCompletion(aiChat.client, request, os.Stdout)
+        } else {
+            err = streamCompletion(aiChat.client, request, os.Stdout)
+        }
         if err != nil {
             log.Fatal(err)
         }
-        if len(response.Choices) == 0 {
-            log.Fatal("no choices")
-        }
-        fmt.Println(response.Choices[0].Message.Content)
     }

 }
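
For context, a minimal sketch of how the two completion paths added by this change could be exercised directly. demoCompletions is a hypothetical helper, not part of the commit; it assumes it compiles in the same package as streamCompletion and nonStreamCompletion above, and that building the client from gogpt.NewClient with an OPENAI_API_KEY environment variable matches how this program obtains its client.

// demoCompletions is a hypothetical helper (not in this commit) that runs
// the same request through both completion paths added above.
func demoCompletions() error {
    // Assumption: the client is constructed from an API key in the environment.
    client := gogpt.NewClient(os.Getenv("OPENAI_API_KEY"))
    request := gogpt.ChatCompletionRequest{
        Model: gogpt.GPT3Dot5Turbo,
        Messages: []gogpt.ChatCompletionMessage{
            {Role: gogpt.ChatMessageRoleUser, Content: "Say hello."},
        },
        Temperature: 0.5,
        MaxTokens:   100,
    }
    // Non-streaming blocks until the full reply is available and prints it
    // in a single write ...
    if err := nonStreamCompletion(client, request, os.Stdout); err != nil {
        return err
    }
    // ... while streaming prints each delta as the server sends it.
    return streamCompletion(client, request, os.Stdout)
}

On the command line, the new mode would presumably be selected with the long flag registered above, e.g. aichat --non-streaming, with streaming output remaining the default when the flag is omitted.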
