Skip to content

Commit

Permalink
replace map with syncmap
Browse files Browse the repository at this point in the history
Signed-off-by: Gabriele Santomaggio <G.santomaggio@gmail.com>
  • Loading branch information
Gsantomaggio committed Jan 21, 2025
1 parent feb6917 commit 483c3c3
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 38 deletions.
90 changes: 52 additions & 38 deletions pkg/stream/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,32 +470,41 @@ func (envOptions *EnvironmentOptions) SetRPCTimeout(timeout time.Duration) *Envi

type environmentCoordinator struct {
mutex *sync.Mutex
mutexContext *sync.RWMutex
clientsPerContext map[int]*Client
clientsPerContext sync.Map
maxItemsForClient int
nextId int
}

func (cc *environmentCoordinator) isProducerListFull(clientsPerContextId int) bool {
return cc.clientsPerContext[clientsPerContextId].coordinator.
ProducersCount() >= cc.maxItemsForClient
client, ok := cc.clientsPerContext.Load(clientsPerContextId)
if !ok {
logs.LogError("client not found")
return false
}
return client.(*Client).coordinator.ProducersCount() >= cc.maxItemsForClient

}

func (cc *environmentCoordinator) isConsumerListFull(clientsPerContextId int) bool {
return cc.clientsPerContext[clientsPerContextId].coordinator.
ConsumersCount() >= cc.maxItemsForClient
client, ok := cc.clientsPerContext.Load(clientsPerContextId)
if !ok {
logs.LogError("client not found")
return false
}
return client.(*Client).coordinator.ConsumersCount() >= cc.maxItemsForClient
}

func (cc *environmentCoordinator) maybeCleanClients() {
cc.mutex.Lock()
defer cc.mutex.Unlock()
cc.mutexContext.Lock()
defer cc.mutexContext.Unlock()
for i, client := range cc.clientsPerContext {

cc.clientsPerContext.Range(func(key, value any) bool {
client := value.(*Client)
if !client.socket.isOpen() {
delete(cc.clientsPerContext, i)
cc.clientsPerContext.Delete(key)
}
}
return true
})
}

func (c *Client) maybeCleanProducers(streamName string) {
Expand Down Expand Up @@ -541,15 +550,16 @@ func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCP
options *ProducerOptions, rpcTimeout time.Duration) (*Producer, error) {
cc.mutex.Lock()
defer cc.mutex.Unlock()
cc.mutexContext.Lock()
defer cc.mutexContext.Unlock()
var clientResult *Client
for i, client := range cc.clientsPerContext {
if !cc.isProducerListFull(i) {
clientResult = client
break

cc.clientsPerContext.Range(func(key, value any) bool {
if !cc.isProducerListFull(key.(int)) {
clientResult = value.(*Client)
return false
}
}
return true
})

clientProvidedName := "go-stream-producer"
if options != nil && options.ClientProvidedName != "" {
clientProvidedName = options.ClientProvidedName
Expand Down Expand Up @@ -593,7 +603,8 @@ func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCP
func (cc *environmentCoordinator) newClientForProducer(connectionName string, leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, rpcTimeOut time.Duration) *Client {
clientResult := newClient(connectionName, leader, tcpParameters, saslConfiguration, rpcTimeOut)
cc.nextId++
cc.clientsPerContext[cc.nextId] = clientResult

cc.clientsPerContext.Store(cc.nextId, clientResult)
return clientResult
}

Expand All @@ -602,20 +613,20 @@ func (cc *environmentCoordinator) newConsumer(connectionName string, leader *Bro
options *ConsumerOptions, rpcTimeout time.Duration) (*Consumer, error) {
cc.mutex.Lock()
defer cc.mutex.Unlock()
cc.mutexContext.Lock()
defer cc.mutexContext.Unlock()
var clientResult *Client
for i, client := range cc.clientsPerContext {
if !cc.isConsumerListFull(i) {
clientResult = client
break

cc.clientsPerContext.Range(func(key, value any) bool {
if !cc.isConsumerListFull(key.(int)) {
clientResult = value.(*Client)
return false
}
}
return true
})

if clientResult == nil {
clientResult = newClient(connectionName, leader, tcpParameters, saslConfiguration, rpcTimeout)
cc.nextId++
cc.clientsPerContext[cc.nextId] = clientResult
cc.clientsPerContext.Store(cc.nextId, clientResult)
}
// try to reconnect in case the socket is closed
err := clientResult.connect()
Expand All @@ -632,23 +643,28 @@ func (cc *environmentCoordinator) newConsumer(connectionName string, leader *Bro
}

func (cc *environmentCoordinator) Close() error {
cc.mutexContext.Lock()
defer cc.mutexContext.Unlock()
for _, client := range cc.clientsPerContext {

cc.clientsPerContext.Range(func(key, value any) bool {
client := value.(*Client)
for i := range client.coordinator.producers {
_ = client.coordinator.producers[i].(*Producer).Close()
}
for i := range client.coordinator.consumers {
_ = client.coordinator.consumers[i].(*Consumer).Close()
}
}
return true
})

return nil
}

func (cc *environmentCoordinator) getClientsPerContext() map[int]*Client {
cc.mutexContext.Lock()
defer cc.mutexContext.Unlock()
return cc.clientsPerContext
clients := map[int]*Client{}
cc.clientsPerContext.Range(func(key, value any) bool {
clients[key.(int)] = value.(*Client)
return true
})
return clients
}

type producersEnvironment struct {
Expand Down Expand Up @@ -677,10 +693,9 @@ func (ps *producersEnvironment) newProducer(clientLocator *Client, streamName st
coordinatorKey := leader.hostPort()
if ps.producersCoordinator[coordinatorKey] == nil {
ps.producersCoordinator[coordinatorKey] = &environmentCoordinator{
clientsPerContext: map[int]*Client{},
clientsPerContext: sync.Map{},
mutex: &sync.Mutex{},
maxItemsForClient: ps.maxItemsForClient,
mutexContext: &sync.RWMutex{},
nextId: 0,
}
}
Expand Down Expand Up @@ -742,10 +757,9 @@ func (ps *consumersEnvironment) NewSubscriber(clientLocator *Client, streamName
coordinatorKey := consumerBroker.hostPort()
if ps.consumersCoordinator[coordinatorKey] == nil {
ps.consumersCoordinator[coordinatorKey] = &environmentCoordinator{
clientsPerContext: map[int]*Client{},
clientsPerContext: sync.Map{},
mutex: &sync.Mutex{},
maxItemsForClient: ps.maxItemsForClient,
mutexContext: &sync.RWMutex{},
nextId: 0,
}
}
Expand Down
57 changes: 57 additions & 0 deletions pkg/stream/environment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package stream

import (
"crypto/tls"
"github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp"
"sync"
"time"

Expand Down Expand Up @@ -419,4 +420,60 @@ var _ = Describe("Environment test", func() {
Expect(env.Close()).NotTo(HaveOccurred())
})

It("close env should close all the producers and consumers ", func() {
env, err := NewEnvironment(NewEnvironmentOptions().
SetMaxConsumersPerClient(2).
SetMaxConsumersPerClient(3))

Expect(err).NotTo(HaveOccurred())
streamName := uuid.New().String()
Expect(env.DeclareStream(streamName, nil)).NotTo(HaveOccurred())
for i := 0; i < 5; i++ {
_, err := env.NewProducer(streamName, nil)
Expect(err).NotTo(HaveOccurred())
}

for i := 0; i < 5; i++ {
_, err := env.NewConsumer(streamName, func(consumerContext ConsumerContext, message *amqp.Message) {

}, nil)
Expect(err).NotTo(HaveOccurred())
}

// count element sync map
count := 0
env.consumers.getCoordinators()["localhost:5552"].clientsPerContext.Range(func(key, value any) bool {
Expect(value).NotTo(BeNil())
count++
return true
})

Expect(count).To(Equal(2))

Expect(env.Close()).NotTo(HaveOccurred())

// count element sync map

Eventually(func() int {
count = 0
env.producers.getCoordinators()["localhost:5552"].clientsPerContext.Range(func(key, value any) bool {
Expect(value).To(BeNil())
count++
return true
})
return count
}, "5s", "1s").Should(Equal(0))

Eventually(func() int {
count = 0
env.consumers.getCoordinators()["localhost:5552"].clientsPerContext.Range(func(key, value any) bool {
Expect(value).To(BeNil())
count++
return true
})
return count
}, "5s", "1s").Should(Equal(0))

})

})

0 comments on commit 483c3c3

Please sign in to comment.