diff --git a/connection.go b/connection.go index 6b0ac10..a2435f0 100644 --- a/connection.go +++ b/connection.go @@ -18,7 +18,7 @@ type ConnectionConfig struct { } // SetDefaultClient sets the default MongoDB client to be used by the package. -func SetDefaultClient(client *mongo.Client, dbName string) { +func SetDefaultClient(client *mongo.Client) { mClient = client } diff --git a/connection_cache.go b/connection_cache.go index cc8ed96..8c9faeb 100644 --- a/connection_cache.go +++ b/connection_cache.go @@ -1,21 +1,54 @@ package mgod -import "go.mongodb.org/mongo-driver/mongo" +import ( + "sync" + + "go.mongodb.org/mongo-driver/mongo" +) // dbConnCache is a cache of MongoDB database connections. -var dbConnCache map[string]*mongo.Database +var dbConnCache *connectionCache func init() { - dbConnCache = make(map[string]*mongo.Database) + dbConnCache = newConnectionCache() +} + +// connectionCache is a thread safe construct to cache MongoDB database connections. +type connectionCache struct { + cache map[string]*mongo.Database + mux sync.RWMutex +} + +func newConnectionCache() *connectionCache { + return &connectionCache{ + cache: map[string]*mongo.Database{}, + } +} + +func (c *connectionCache) Get(dbName string) *mongo.Database { + c.mux.RLock() + defer c.mux.RUnlock() + + return c.cache[dbName] +} + +func (c *connectionCache) Set(dbName string, db *mongo.Database) { + c.mux.Lock() + defer c.mux.Unlock() + + c.cache[dbName] = db } // getDBConn returns a MongoDB database connection from the cache. // If the connection is not present in the cache, it creates a new connection and adds it to the cache (Write-through policy). func getDBConn(dbName string) *mongo.Database { + dbConn := dbConnCache.Get(dbName) + // Initialize the cache entry if it is not present. - if dbConnCache[dbName] == nil { - dbConnCache[dbName] = mClient.Database(dbName) + if dbConn == nil { + dbConn = mClient.Database(dbName) + dbConnCache.Set(dbName, dbConn) } - return dbConnCache[dbName] + return dbConn } diff --git a/docs/basic_usage.md b/docs/basic_usage.md index cca3861..0d96a28 100644 --- a/docs/basic_usage.md +++ b/docs/basic_usage.md @@ -10,7 +10,7 @@ import "github.com/Lyearn/mgod" func init() { // client is the MongoDB client obtained using Go Mongo Driver's Connect method. - mgod.SetDefaultClient(client, dbName) + mgod.SetDefaultClient(client) } ``` diff --git a/docs/multi_tenancy.md b/docs/multi_tenancy.md index 058b560..af41cc0 100644 --- a/docs/multi_tenancy.md +++ b/docs/multi_tenancy.md @@ -37,8 +37,6 @@ The `EntityMongoModel` is always bound to the specified database at the time of ::: ```go -amount := 10000 - result, _ := tenant1Model.FindOne(context.TODO(), bson.M{"name": "Gopher Tenant 2"}) // result will be value in this case ``` diff --git a/entity_mongo_model_test.go b/entity_mongo_model_test.go index f0bf1bc..f394324 100644 --- a/entity_mongo_model_test.go +++ b/entity_mongo_model_test.go @@ -86,8 +86,11 @@ func (s *EntityMongoModelSuite) setupData() { } func (s *EntityMongoModelSuite) getModel() mgod.EntityMongoModel[testEntity] { + dbName := "mgoddb" + collection := "entityMongoModel" schemaOpts := schemaopt.SchemaOptions{Timestamps: true} - opts := mgod.NewEntityMongoModelOptions("mgoddb", "entityMongoModel", &schemaOpts) + + opts := mgod.NewEntityMongoModelOptions(dbName, collection, &schemaOpts) model, err := mgod.NewEntityMongoModel(testEntity{}, *opts) if err != nil { s.T().Fatal(err)