Skip to content

Commit

Permalink
fix: make registerTask safe on concurrent use (#616)
Browse files Browse the repository at this point in the history
Co-authored-by: chengwanli <chengwanli@p1staff.com>
  • Loading branch information
HitmanRanbo and chengwanli authored Nov 5, 2020
1 parent 260989d commit 4554c7e
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 29 deletions.
16 changes: 12 additions & 4 deletions v1/common/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package common

import (
"errors"
"sync"

"github.com/RichardKnop/machinery/v1/brokers/iface"
"github.com/RichardKnop/machinery/v1/config"
Expand All @@ -10,10 +11,15 @@ import (
"github.com/RichardKnop/machinery/v1/tasks"
)

type registeredTaskNames struct {
sync.RWMutex
items []string
}

// Broker represents a base broker structure
type Broker struct {
cnf *config.Config
registeredTaskNames []string
registeredTaskNames registeredTaskNames
retry bool
retryFunc func(chan int)
retryStopChan chan int
Expand Down Expand Up @@ -62,12 +68,14 @@ func (b *Broker) Publish(signature *tasks.Signature) error {

// SetRegisteredTaskNames sets registered task names
func (b *Broker) SetRegisteredTaskNames(names []string) {
b.registeredTaskNames = names
b.registeredTaskNames.Lock()
defer b.registeredTaskNames.Unlock()
b.registeredTaskNames.items = names
}

// IsTaskRegistered returns true if the task is registered with this broker
func (b *Broker) IsTaskRegistered(name string) bool {
for _, registeredTaskName := range b.registeredTaskNames {
for _, registeredTaskName := range b.registeredTaskNames.items {
if registeredTaskName == name {
return true
}
Expand Down Expand Up @@ -110,7 +118,7 @@ func (b *Broker) StopConsuming() {

// GetRegisteredTaskNames returns registered tasks names
func (b *Broker) GetRegisteredTaskNames() []string {
return b.registeredTaskNames
return b.registeredTaskNames.items
}

// AdjustRoutingKey makes sure the routing key is correct.
Expand Down
28 changes: 16 additions & 12 deletions v1/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// All the tasks workers process are registered against the server
type Server struct {
config *config.Config
registeredTasks map[string]interface{}
registeredTasks *sync.Map
broker brokersiface.Broker
backend backendsiface.Backend
lock lockiface.Lock
Expand All @@ -40,7 +40,7 @@ type Server struct {
func NewServerWithBrokerBackendLock(cnf *config.Config, brokerServer brokersiface.Broker, backendServer backendsiface.Backend, lock lockiface.Lock) *Server {
srv := &Server{
config: cnf,
registeredTasks: map[string]interface{}{},
registeredTasks: new(sync.Map),
broker: brokerServer,
backend: backendServer,
lock: lock,
Expand Down Expand Up @@ -143,7 +143,11 @@ func (server *Server) RegisterTasks(namedTaskFuncs map[string]interface{}) error
return err
}
}
server.registeredTasks = namedTaskFuncs

for k, v := range namedTaskFuncs {
server.registeredTasks.Store(k, v)
}

server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
return nil
}
Expand All @@ -153,20 +157,20 @@ func (server *Server) RegisterTask(name string, taskFunc interface{}) error {
if err := tasks.ValidateTask(taskFunc); err != nil {
return err
}
server.registeredTasks[name] = taskFunc
server.registeredTasks.Store(name, taskFunc)
server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
return nil
}

// IsTaskRegistered returns true if the task name is registered with this broker
func (server *Server) IsTaskRegistered(name string) bool {
_, ok := server.registeredTasks[name]
_, ok := server.registeredTasks.Load(name)
return ok
}

// GetRegisteredTask returns registered task by name
func (server *Server) GetRegisteredTask(name string) (interface{}, error) {
taskFunc, ok := server.registeredTasks[name]
taskFunc, ok := server.registeredTasks.Load(name)
if !ok {
return nil, fmt.Errorf("Task not registered error: %s", name)
}
Expand Down Expand Up @@ -340,12 +344,12 @@ func (server *Server) SendChord(chord *tasks.Chord, sendConcurrency int) (*resul

// GetRegisteredTaskNames returns slice of registered task names
func (server *Server) GetRegisteredTaskNames() []string {
taskNames := make([]string, len(server.registeredTasks))
var i = 0
for name := range server.registeredTasks {
taskNames[i] = name
i++
}
taskNames := make([]string, 0)

server.registeredTasks.Range(func(key, value interface{}) bool {
taskNames = append(taskNames, key.(string))
return true
})
return taskNames
}

Expand Down
14 changes: 14 additions & 0 deletions v1/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ func TestRegisterTask(t *testing.T) {
assert.NoError(t, err, "test_task is not registered but it should be")
}

func TestRegisterTaskInRaceCondition(t *testing.T) {
t.Parallel()

server := getTestServer(t)
for i:=0; i<10; i++ {
go func() {
err := server.RegisterTask("test_task", func() error { return nil })
assert.NoError(t, err)
_, err = server.GetRegisteredTask("test_task")
assert.NoError(t, err, "test_task is not registered but it should be")
}()
}
}

func TestGetRegisteredTask(t *testing.T) {
t.Parallel()

Expand Down
28 changes: 15 additions & 13 deletions v2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
// All the tasks workers process are registered against the server
type Server struct {
config *config.Config
registeredTasks map[string]interface{}
registeredTasks *sync.Map
broker brokersiface.Broker
backend backendsiface.Backend
lock lockiface.Lock
Expand All @@ -39,7 +39,7 @@ type Server struct {
func NewServer(cnf *config.Config, brokerServer brokersiface.Broker, backendServer backendsiface.Backend, lock lockiface.Lock) *Server {
srv := &Server{
config: cnf,
registeredTasks: make(map[string]interface{}),
registeredTasks: new(sync.Map),
broker: brokerServer,
backend: backendServer,
lock: lock,
Expand All @@ -56,7 +56,7 @@ func NewServer(cnf *config.Config, brokerServer brokersiface.Broker, backendServ
func NewServerWithBrokerBackendLock(cnf *config.Config, brokerServer brokersiface.Broker, backendServer backendsiface.Backend, lock lockiface.Lock) *Server {
srv := &Server{
config: cnf,
registeredTasks: map[string]interface{}{},
registeredTasks: new(sync.Map),
broker: brokerServer,
backend: backendServer,
lock: lock,
Expand Down Expand Up @@ -131,7 +131,9 @@ func (server *Server) RegisterTasks(namedTaskFuncs map[string]interface{}) error
return err
}
}
server.registeredTasks = namedTaskFuncs
for k, v := range namedTaskFuncs {
server.registeredTasks.Store(k, v)
}
server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
return nil
}
Expand All @@ -141,20 +143,20 @@ func (server *Server) RegisterTask(name string, taskFunc interface{}) error {
if err := tasks.ValidateTask(taskFunc); err != nil {
return err
}
server.registeredTasks[name] = taskFunc
server.registeredTasks.Store(name, taskFunc)
server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
return nil
}

// IsTaskRegistered returns true if the task name is registered with this broker
func (server *Server) IsTaskRegistered(name string) bool {
_, ok := server.registeredTasks[name]
_, ok := server.registeredTasks.Load(name)
return ok
}

// GetRegisteredTask returns registered task by name
func (server *Server) GetRegisteredTask(name string) (interface{}, error) {
taskFunc, ok := server.registeredTasks[name]
taskFunc, ok := server.registeredTasks.Load(name)
if !ok {
return nil, fmt.Errorf("Task not registered error: %s", name)
}
Expand Down Expand Up @@ -328,12 +330,12 @@ func (server *Server) SendChord(chord *tasks.Chord, sendConcurrency int) (*resul

// GetRegisteredTaskNames returns slice of registered task names
func (server *Server) GetRegisteredTaskNames() []string {
taskNames := make([]string, len(server.registeredTasks))
var i = 0
for name := range server.registeredTasks {
taskNames[i] = name
i++
}
taskNames := make([]string, 0)

server.registeredTasks.Range(func(key, value interface{}) bool {
taskNames = append(taskNames, key.(string))
return true
})
return taskNames
}

Expand Down

0 comments on commit 4554c7e

Please sign in to comment.