Skip to content

Commit

Permalink
Refactor embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Dec 15, 2023
1 parent a05341b commit d836ce9
Show file tree
Hide file tree
Showing 27 changed files with 232 additions and 286 deletions.
48 changes: 36 additions & 12 deletions embedding/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/hupe1980/golc/schema"
"golang.org/x/sync/errgroup"
)

// Compile time check to ensure Bedrock satisfies the Embedder interface.
var _ schema.Embedder = (*Bedrock)(nil)

// amazonOutput represents the expected JSON output structure from the Bedrock model.
type amazonOutput struct {
Embedding []float64 `json:"embedding"`
Embedding []float32 `json:"embedding"`
}

// BedrockRuntimeClient is an interface for the Bedrock model runtime client.
Expand All @@ -27,6 +28,8 @@ type BedrockOptions struct {
*schema.CallbackOptions `map:"-"`
schema.Tokenizer `map:"-"`

MaxConcurrency int

// Model id to use.
ModelID string `map:"model_id,omitempty"`

Expand All @@ -43,7 +46,8 @@ type Bedrock struct {
// NewBedrock creates a new instance of Bedrock with the provided BedrockRuntimeClient and optional configuration.
func NewBedrock(client BedrockRuntimeClient, optFns ...func(o *BedrockOptions)) *Bedrock {
opts := BedrockOptions{
ModelID: "amazon.titan-embed-text-v1",
MaxConcurrency: 5,
ModelID: "amazon.titan-embed-text-v1",
}

for _, fn := range optFns {
Expand All @@ -56,24 +60,44 @@ func NewBedrock(client BedrockRuntimeClient, optFns ...func(o *BedrockOptions))
}
}

// EmbedDocuments embeds a list of documents and returns their embeddings.
func (e *Bedrock) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
embeddings := make([][]float64, len(texts))
// BatchEmbedText embeds a list of texts and returns their embeddings.
func (e *Bedrock) BatchEmbedText(ctx context.Context, texts []string) ([][]float32, error) {
errs, errctx := errgroup.WithContext(ctx)

// Use a semaphore to control concurrency
sem := make(chan struct{}, e.opts.MaxConcurrency)

embeddings := make([][]float32, len(texts))

for i, text := range texts {
embedding, err := e.EmbedQuery(ctx, text)
if err != nil {
return nil, err
}
// Acquire semaphore, limit concurrency
sem <- struct{}{}

i, text := i, text

errs.Go(func() error {
defer func() { <-sem }() // Release semaphore when done

embeddings[i] = embedding
embedding, err := e.EmbedText(errctx, text)
if err != nil {
return err
}

embeddings[i] = embedding

return nil
})
}

if err := errs.Wait(); err != nil {
return nil, err
}

return embeddings, nil
}

// EmbedQuery embeds a single query and returns its embedding.
func (e *Bedrock) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
// EmbedText embeds a single text and returns its embedding.
func (e *Bedrock) EmbedText(ctx context.Context, text string) ([]float32, error) {
jsonBody := map[string]string{
"inputText": removeNewLines(text),
}
Expand Down
18 changes: 9 additions & 9 deletions embedding/bedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestBedrock(t *testing.T) {
texts := []string{"text1", "text2"}

// Embed the documents.
embeddings, err := embedder.EmbedDocuments(context.Background(), texts)
embeddings, err := embedder.BatchEmbedText(context.Background(), texts)

// Add your assertions using testify
assert.NoError(t, err, "Expected no error")
Expand All @@ -44,11 +44,11 @@ func TestBedrock(t *testing.T) {
}
embedder := NewBedrock(client)

// Define a query text.
query := "query text"
// Define a text.
text := "text"

// Embed the query.
embedding, err := embedder.EmbedQuery(context.Background(), query)
// Embed the text.
embedding, err := embedder.EmbedText(context.Background(), text)

// Add your assertions using testify
assert.NoError(t, err, "Expected no error")
Expand All @@ -63,11 +63,11 @@ func TestBedrock(t *testing.T) {
}
embedder := NewBedrock(client)

// Define a query text.
query := "query text"
// Define a text.
text := "text"

// Embed the query.
embedding, err := embedder.EmbedQuery(context.Background(), query)
// Embed the text.
embedding, err := embedder.EmbedText(context.Background(), text)

// Add your assertions using testify
assert.Error(t, err, "Expected an error")
Expand Down
42 changes: 24 additions & 18 deletions embedding/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"errors"

"github.com/avast/retry-go"
"github.com/cohere-ai/cohere-go"
cohere "github.com/cohere-ai/cohere-go/v2"
cohereclient "github.com/cohere-ai/cohere-go/v2/client"
core "github.com/cohere-ai/cohere-go/v2/core"
"github.com/hupe1980/golc/internal/util"
"github.com/hupe1980/golc/schema"
)

Expand All @@ -14,7 +17,7 @@ var _ schema.Embedder = (*Cohere)(nil)

// CohereClient is an interface for the Cohere client.
type CohereClient interface {
Embed(opts cohere.EmbedOptions) (*cohere.EmbedResponse, error)
Embed(ctx context.Context, request *cohere.EmbedRequest) (*cohere.EmbedResponse, error)
}

// CohereOptions contains options for configuring the Cohere instance.
Expand All @@ -36,10 +39,7 @@ type Cohere struct {
// NewCohere creates a new Cohere instance with the provided API key and options.
// It returns the initialized Cohere instance or an error if initialization fails.
func NewCohere(apiKey string, optFns ...func(o *CohereOptions)) (*Cohere, error) {
client, err := cohere.CreateClient(apiKey)
if err != nil {
return nil, err
}
client := cohereclient.NewClient(cohereclient.WithToken(apiKey))

return NewCohereFromClient(client, optFns...)
}
Expand All @@ -50,6 +50,7 @@ func NewCohereFromClient(client CohereClient, optFns ...func(o *CohereOptions))
opts := CohereOptions{
Model: "embed-english-v3.0",
MaxRetries: 3,
Truncate: "NONE",
}

for _, fn := range optFns {
Expand All @@ -62,26 +63,31 @@ func NewCohereFromClient(client CohereClient, optFns ...func(o *CohereOptions))
}, nil
}

// EmbedDocuments embeds a list of documents and returns their embeddings.
func (e *Cohere) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
res, err := e.embedWithRetry(cohere.EmbedOptions{
Model: e.opts.Model,
Truncate: e.opts.Truncate,
// BatchEmbedText embeds a list of texts and returns their embeddings.
func (e *Cohere) BatchEmbedText(ctx context.Context, texts []string) ([][]float32, error) {
res, err := e.embedWithRetry(ctx, &cohere.EmbedRequest{
Model: util.AddrOrNil(e.opts.Model),
Truncate: cohere.EmbedRequestTruncate(e.opts.Truncate).Ptr(),
Texts: texts,
})
if err != nil {
return nil, err
}

return res.Embeddings, nil
embeddings := make([][]float32, len(res.Embeddings))
for i, r := range res.Embeddings {
embeddings[i] = util.Float64ToFloat32(r)
}

return embeddings, nil
}

func (e *Cohere) embedWithRetry(opts cohere.EmbedOptions) (*cohere.EmbedResponse, error) {
func (e *Cohere) embedWithRetry(ctx context.Context, req *cohere.EmbedRequest) (*cohere.EmbedResponse, error) {
retryOpts := []retry.Option{
retry.Attempts(e.opts.MaxRetries),
retry.DelayType(retry.FixedDelay),
retry.RetryIf(func(err error) bool {
e := &cohere.APIError{}
e := new(core.APIError)
if errors.As(err, &e) {
switch e.StatusCode {
case 429, 500:
Expand All @@ -99,7 +105,7 @@ func (e *Cohere) embedWithRetry(opts cohere.EmbedOptions) (*cohere.EmbedResponse

err := retry.Do(
func() error {
r, cErr := e.client.Embed(opts)
r, cErr := e.client.Embed(ctx, req)
if cErr != nil {
return cErr
}
Expand All @@ -112,9 +118,9 @@ func (e *Cohere) embedWithRetry(opts cohere.EmbedOptions) (*cohere.EmbedResponse
return res, err
}

// EmbedQuery embeds a single query and returns its embedding.
func (e *Cohere) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
embeddings, err := e.EmbedDocuments(ctx, []string{text})
// EmbedText embeds a single query and returns its embedding.
func (e *Cohere) EmbedText(ctx context.Context, text string) ([]float32, error) {
embeddings, err := e.BatchEmbedText(ctx, []string{text})
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions embedding/cohere_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"errors"
"testing"

"github.com/cohere-ai/cohere-go"
cohere "github.com/cohere-ai/cohere-go/v2"
"github.com/stretchr/testify/assert"
)

Expand All @@ -30,7 +30,7 @@ func TestCohere(t *testing.T) {
texts := []string{"text1", "text2"}

// Embed the documents.
embeddings, err := cohereModel.EmbedDocuments(context.Background(), texts)
embeddings, err := cohereModel.BatchEmbedText(context.Background(), texts)
assert.NoError(t, err, "Expected no error")
assert.NotNil(t, embeddings, "Expected non-nil embeddings")
assert.Len(t, embeddings, 2, "Expected 2 embeddings")
Expand All @@ -57,7 +57,7 @@ func TestCohere(t *testing.T) {
query := "query text"

// Embed the query.
embedding, err := cohereModel.EmbedQuery(context.Background(), query)
embedding, err := cohereModel.EmbedText(context.Background(), query)
assert.NoError(t, err, "Expected no error")
assert.NotNil(t, embedding, "Expected non-nil embedding")
assert.Len(t, embedding, 3, "Expected 3 values in the embedding")
Expand All @@ -82,7 +82,7 @@ func TestCohere(t *testing.T) {
query := "query text"

// Embed the query.
embedding, err := cohereModel.EmbedQuery(context.Background(), query)
embedding, err := cohereModel.EmbedText(context.Background(), query)
assert.Error(t, err, "Expected an error")
assert.Nil(t, embedding, "Expected nil embedding")
})
Expand All @@ -95,7 +95,7 @@ type mockCohereClient struct {
err error
}

func (m *mockCohereClient) Embed(opts cohere.EmbedOptions) (*cohere.EmbedResponse, error) {
func (m *mockCohereClient) Embed(ctx context.Context, request *cohere.EmbedRequest) (*cohere.EmbedResponse, error) {
if m.err != nil {
return nil, m.err
}
Expand Down
10 changes: 5 additions & 5 deletions embedding/ernie.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ func NewErnieFromClient(client ErnieClient, optFns ...func(o *ErnieOptions)) *Er
}
}

// EmbedDocuments embeds a list of documents and returns their embeddings.
func (e *Ernie) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
// BatchEmbedText embeds a list of texts and returns their embeddings.
func (e *Ernie) BatchEmbedText(ctx context.Context, texts []string) ([][]float32, error) {
// The number of texts does not exceed 16
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu
chunks := util.ChunkBy(texts, e.chunkSize)

embeddings := make([][]float64, len(texts))
embeddings := make([][]float32, len(texts))

for i, chunk := range chunks {
res, err := e.client.CreateEmbedding(ctx, e.opts.Model, ernie.EmbeddingRequest{
Expand All @@ -77,8 +77,8 @@ func (e *Ernie) EmbedDocuments(ctx context.Context, texts []string) ([][]float64
return embeddings, nil
}

// EmbedQuery embeds a single query and returns its embedding.
func (e *Ernie) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
// EmbedText embeds a single text and returns its embedding.
func (e *Ernie) EmbedText(ctx context.Context, text string) ([]float32, error) {
res, err := e.client.CreateEmbedding(ctx, e.opts.Model, ernie.EmbeddingRequest{
Input: []string{text},
})
Expand Down
24 changes: 12 additions & 12 deletions embedding/ernie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@ func TestEmbedding(t *testing.T) {
Object: "fakeObject",
Data: []struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}{
{
Object: "text1",
Embedding: []float64{1.0, 2.0, 3.0},
Embedding: []float32{1.0, 2.0, 3.0},
Index: 0,
},
{
Object: "text2",
Embedding: []float64{4.0, 5.0, 6.0},
Embedding: []float32{4.0, 5.0, 6.0},
Index: 1,
},
{
Object: "text3",
Embedding: []float64{7.0, 8.0, 9.0},
Embedding: []float32{7.0, 8.0, 9.0},
Index: 2,
},
},
Expand All @@ -44,14 +44,14 @@ func TestEmbedding(t *testing.T) {
texts := []string{"text1", "text2", "text3"}

// Test embedding of documents.
embeddings, err := ernieEmbed.EmbedDocuments(ctx, texts)
embeddings, err := ernieEmbed.BatchEmbedText(ctx, texts)

assert.NoError(t, err, "Error embedding documents")
assert.Len(t, embeddings, len(texts), "Unexpected number of embeddings")

assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0}, embeddings[0])
assert.ElementsMatch(t, []float64{4.0, 5.0, 6.0}, embeddings[1])
assert.ElementsMatch(t, []float64{7.0, 8.0, 9.0}, embeddings[2])
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0}, embeddings[0])
assert.ElementsMatch(t, []float32{4.0, 5.0, 6.0}, embeddings[1])
assert.ElementsMatch(t, []float32{7.0, 8.0, 9.0}, embeddings[2])
})

t.Run("EmbedQuery", func(t *testing.T) {
Expand All @@ -62,12 +62,12 @@ func TestEmbedding(t *testing.T) {
Object: "fakeObject",
Data: []struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}{
{
Object: "fakeEmbedding",
Embedding: []float64{1.0, 2.0, 3.0},
Embedding: []float32{1.0, 2.0, 3.0},
Index: 0,
},
},
Expand All @@ -79,11 +79,11 @@ func TestEmbedding(t *testing.T) {
query := "queryText"

// Test embedding of a query.
embedding, err := ernieEmbed.EmbedQuery(ctx, query)
embedding, err := ernieEmbed.EmbedText(ctx, query)

assert.NoError(t, err, "Error embedding query")
assert.Len(t, embedding, 3, "Unexpected embedding dimensions")
expected := []float64{1.0, 2.0, 3.0} // Mocked embedding values
expected := []float32{1.0, 2.0, 3.0} // Mocked embedding values
assert.ElementsMatch(t, expected, embedding, "Embedding values do not match for the query")
})
}
Expand Down
Loading

0 comments on commit d836ce9

Please sign in to comment.