diff --git a/vectorstore/in_memory.go b/vectorstore/in_memory.go index 10e079d..5464754 100644 --- a/vectorstore/in_memory.go +++ b/vectorstore/in_memory.go @@ -3,6 +3,8 @@ package vectorstore import ( "container/heap" "context" + "encoding/gob" + "io" "github.com/hupe1980/golc/internal/util" "github.com/hupe1980/golc/metric" @@ -195,3 +197,26 @@ func (vs *InMemory) SimilaritySearch(ctx context.Context, query string) ([]schem return documents, nil } + +func (vs *InMemory) Load(r io.Reader) error { + decoder := gob.NewDecoder(r) + + // Decode the data + if err := decoder.Decode(&vs.data); err != nil { + return err + } + + return nil +} + +// Save saves the data to an io.Writer. +func (vs *InMemory) Save(w io.Writer) error { + encoder := gob.NewEncoder(w) + + // Encode the data + if err := encoder.Encode(vs.data); err != nil { + return err + } + + return nil +} diff --git a/vectorstore/in_memory_test.go b/vectorstore/in_memory_test.go index 458f7d2..04a4837 100644 --- a/vectorstore/in_memory_test.go +++ b/vectorstore/in_memory_test.go @@ -1,10 +1,12 @@ package vectorstore import ( + "bytes" "context" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/hupe1980/golc/schema" ) @@ -52,6 +54,31 @@ func TestInMemory(t *testing.T) { assert.Equal(t, expectedDocuments[i].PageContent, doc.PageContent) } }) + + t.Run("SaveAndLoad", func(t *testing.T) { + originalData := []InMemoryItem{ + {Content: "item1", Vector: []float32{1.0, 2.0, 3.0}, Metadata: map[string]any{"key1": "value1"}}, + {Content: "item2", Vector: []float32{4.0, 5.0, 6.0}, Metadata: map[string]any{"key2": "value2"}}, + } + + // Create an InMemory instance with the original data + vsOriginal := &InMemory{data: originalData} + + // Serialize the original data + var buf bytes.Buffer + err := vsOriginal.Save(&buf) + require.NoError(t, err, "Failed to save data") + + // Create a new InMemory instance + vsLoaded := &InMemory{} + + // Load the serialized data + err = vsLoaded.Load(&buf) + require.NoError(t, err, "Failed to load data") + + // Check if the loaded data matches the original data + assert.Equal(t, originalData, vsLoaded.data, "Loaded data does not match original data") + }) } // mockEmbedder implements the schema.Embedder interface for testing purposes.