Skip to content

Commit

Permalink
fix: add locking in connection cache
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh-2711 committed Dec 21, 2023
1 parent 9439185 commit 59928b6
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 11 deletions.
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type ConnectionConfig struct {
}

// SetDefaultClient sets the default MongoDB client to be used by the package.
func SetDefaultClient(client *mongo.Client, dbName string) {
func SetDefaultClient(client *mongo.Client) {
mClient = client
}

Expand Down
45 changes: 39 additions & 6 deletions connection_cache.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,54 @@
package mgod

import "go.mongodb.org/mongo-driver/mongo"
import (
"sync"

"go.mongodb.org/mongo-driver/mongo"
)

// dbConnCache is a cache of MongoDB database connections.
var dbConnCache map[string]*mongo.Database
var dbConnCache *connectionCache

func init() {
dbConnCache = make(map[string]*mongo.Database)
dbConnCache = newConnectionCache()
}

// connectionCache is a thread safe construct to cache MongoDB database connections.
type connectionCache struct {
cache map[string]*mongo.Database
mux sync.RWMutex
}

func newConnectionCache() *connectionCache {
return &connectionCache{
cache: map[string]*mongo.Database{},
}
}

func (c *connectionCache) Get(dbName string) *mongo.Database {
c.mux.RLock()
defer c.mux.RUnlock()

return c.cache[dbName]
}

func (c *connectionCache) Set(dbName string, db *mongo.Database) {
c.mux.Lock()
defer c.mux.Unlock()

c.cache[dbName] = db
}

// getDBConn returns a MongoDB database connection from the cache.
// If the connection is not present in the cache, it creates a new connection and adds it to the cache (Write-through policy).
func getDBConn(dbName string) *mongo.Database {
dbConn := dbConnCache.Get(dbName)

// Initialize the cache entry if it is not present.
if dbConnCache[dbName] == nil {
dbConnCache[dbName] = mClient.Database(dbName)
if dbConn == nil {
dbConn = mClient.Database(dbName)
dbConnCache.Set(dbName, dbConn)
}

return dbConnCache[dbName]
return dbConn
}
2 changes: 1 addition & 1 deletion docs/basic_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import "github.com/Lyearn/mgod"

func init() {
// client is the MongoDB client obtained using Go Mongo Driver's Connect method.
mgod.SetDefaultClient(client, dbName)
mgod.SetDefaultClient(client)
}
```

Expand Down
2 changes: 0 additions & 2 deletions docs/multi_tenancy.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ The `EntityMongoModel` is always bound to the specified database at the time of
:::

```go
amount := 10000

result, _ := tenant1Model.FindOne(context.TODO(), bson.M{"name": "Gopher Tenant 2"})
// result will be <nil> value in this case
```
5 changes: 4 additions & 1 deletion entity_mongo_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,11 @@ func (s *EntityMongoModelSuite) setupData() {
}

func (s *EntityMongoModelSuite) getModel() mgod.EntityMongoModel[testEntity] {
dbName := "mgoddb"
collection := "entityMongoModel"
schemaOpts := schemaopt.SchemaOptions{Timestamps: true}
opts := mgod.NewEntityMongoModelOptions("mgoddb", "entityMongoModel", &schemaOpts)

opts := mgod.NewEntityMongoModelOptions(dbName, collection, &schemaOpts)
model, err := mgod.NewEntityMongoModel(testEntity{}, *opts)
if err != nil {
s.T().Fatal(err)
Expand Down

0 comments on commit 59928b6

Please sign in to comment.