From 2c041280b260f5c595c49a5919b01ef380c58796 Mon Sep 17 00:00:00 2001 From: jeff lee Date: Mon, 21 Jun 2021 18:41:18 +0800 Subject: [PATCH] add PublishToLocal --- integration-tests/amqp_amqp_test.go | 4 +- integration-tests/amqp_memcache_test.go | 4 +- integration-tests/amqp_mongodb_test.go | 4 +- integration-tests/amqp_redis_test.go | 4 +- integration-tests/gcppubsub_redis_test.go | 5 +- integration-tests/redis_memcache_test.go | 4 +- integration-tests/redis_mongodb_test.go | 4 +- integration-tests/redis_redis_test.go | 4 +- integration-tests/redis_socket_test.go | 4 +- integration-tests/sqs_amqp_test.go | 4 +- integration-tests/sqs_mongodb_test.go | 2 + integration-tests/suite_test.go | 98 +++++++++++++++++++++++ v1/brokers/amqp/amqp.go | 83 +++++++++++++------ v1/brokers/amqp/amqp_concurrence_test.go | 11 +-- v1/brokers/eager/eager.go | 5 ++ v1/brokers/gcppubsub/gcp_pubsub.go | 5 ++ v1/brokers/iface/interfaces.go | 3 + v1/brokers/redis/goredis.go | 26 +++++- v1/brokers/redis/redis.go | 26 +++++- v1/brokers/sqs/sqs.go | 40 +++++++-- v1/server_test.go | 4 +- v1/worker.go | 2 +- v1/worker_test.go | 2 +- 23 files changed, 294 insertions(+), 54 deletions(-) diff --git a/integration-tests/amqp_amqp_test.go b/integration-tests/amqp_amqp_test.go index 93295d148..d7890455e 100644 --- a/integration-tests/amqp_amqp_test.go +++ b/integration-tests/amqp_amqp_test.go @@ -43,7 +43,9 @@ func TestAmqpAmqp(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + testPubslishToLocal(server, t) } diff --git a/integration-tests/amqp_memcache_test.go b/integration-tests/amqp_memcache_test.go index f6574d8c9..ad7e76ff1 100644 --- a/integration-tests/amqp_memcache_test.go +++ b/integration-tests/amqp_memcache_test.go @@ -39,7 +39,9 @@ func TestAmqpMemcache(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + testPubslishToLocal(server, t) } diff --git a/integration-tests/amqp_mongodb_test.go b/integration-tests/amqp_mongodb_test.go index 9d872f445..cd60817e8 100644 --- a/integration-tests/amqp_mongodb_test.go +++ b/integration-tests/amqp_mongodb_test.go @@ -35,7 +35,9 @@ func TestAmqpMongodb(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + testPubslishToLocal(server, t) } diff --git a/integration-tests/amqp_redis_test.go b/integration-tests/amqp_redis_test.go index 1ce4d26fa..2f4ba786b 100644 --- a/integration-tests/amqp_redis_test.go +++ b/integration-tests/amqp_redis_test.go @@ -34,7 +34,9 @@ func TestAmqpRedis(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + testPubslishToLocal(server, t) } diff --git a/integration-tests/gcppubsub_redis_test.go b/integration-tests/gcppubsub_redis_test.go index 423abd739..b045e9ea4 100644 --- a/integration-tests/gcppubsub_redis_test.go +++ b/integration-tests/gcppubsub_redis_test.go @@ -95,7 +95,10 @@ func TestGCPPubSubRedis(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + // not supported + // testPubslishToLocal(server, t) } diff --git a/integration-tests/redis_memcache_test.go b/integration-tests/redis_memcache_test.go index 68870456c..beaace837 100644 --- a/integration-tests/redis_memcache_test.go +++ b/integration-tests/redis_memcache_test.go @@ -28,7 +28,9 @@ func TestRedisMemcache(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + testPubslishToLocal(server, t) } diff --git a/integration-tests/redis_mongodb_test.go b/integration-tests/redis_mongodb_test.go index b92ed5009..e768f2f42 100644 --- a/integration-tests/redis_mongodb_test.go +++ b/integration-tests/redis_mongodb_test.go @@ -29,7 +29,9 @@ func TestRedisMongodb(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + testPubslishToLocal(server, t) } diff --git a/integration-tests/redis_redis_test.go b/integration-tests/redis_redis_test.go index 58b654f58..04fdcc3f9 100644 --- a/integration-tests/redis_redis_test.go +++ b/integration-tests/redis_redis_test.go @@ -26,9 +26,11 @@ func TestRedisRedis_Redigo(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + testPubslishToLocal(server, t) } func TestRedisRedisNormalTaskPollPeriodLessThan1SecondShouldNotFailNextTask(t *testing.T) { diff --git a/integration-tests/redis_socket_test.go b/integration-tests/redis_socket_test.go index f46dca2ab..0123dc947 100644 --- a/integration-tests/redis_socket_test.go +++ b/integration-tests/redis_socket_test.go @@ -24,7 +24,9 @@ func TestRedisSocket(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + testPubslishToLocal(server, t) } diff --git a/integration-tests/sqs_amqp_test.go b/integration-tests/sqs_amqp_test.go index 552ef7dbe..6ef320baa 100644 --- a/integration-tests/sqs_amqp_test.go +++ b/integration-tests/sqs_amqp_test.go @@ -34,7 +34,9 @@ func TestSQSAmqp(t *testing.T) { }) worker := server.(*machinery.Server).NewWorker("test_worker", 0) - defer worker.Quit() go worker.Launch() testAll(server, t) + worker.Quit() + + testPubslishToLocal(server, t) } diff --git a/integration-tests/sqs_mongodb_test.go b/integration-tests/sqs_mongodb_test.go index da1237249..c30d6e36f 100644 --- a/integration-tests/sqs_mongodb_test.go +++ b/integration-tests/sqs_mongodb_test.go @@ -31,4 +31,6 @@ func TestSQSMongodb(t *testing.T) { go worker.Launch() testAll(server, t) worker.Quit() + + testPubslishToLocal(server, t) } diff --git a/integration-tests/suite_test.go b/integration-tests/suite_test.go index 0e99d1cf4..78ffc1b4d 100644 --- a/integration-tests/suite_test.go +++ b/integration-tests/suite_test.go @@ -3,12 +3,15 @@ package integration_test import ( "context" "errors" + "fmt" "log" "reflect" "sort" + "strings" "testing" "time" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/RichardKnop/machinery/v1" @@ -29,6 +32,7 @@ type Server interface { GetBroker() brokersiface.Broker GetConfig() *config.Config RegisterTasks(namedTaskFuncs map[string]interface{}) error + RegisterTask(name string, taskFunc interface{}) error SendTaskWithContext(ctx context.Context, signature *tasks.Signature) (*result.AsyncResult, error) SendTask(signature *tasks.Signature) (*result.AsyncResult, error) SendChainWithContext(ctx context.Context, chain *tasks.Chain) (*result.ChainAsyncResult, error) @@ -37,6 +41,7 @@ type Server interface { SendGroup(group *tasks.Group, sendConcurrency int) ([]*result.AsyncResult, error) SendChordWithContext(ctx context.Context, chord *tasks.Chord, sendConcurrency int) (*result.ChordAsyncResult, error) SendChord(chord *tasks.Chord, sendConcurrency int) (*result.ChordAsyncResult, error) + NewWorker(consumerTag string, concurrency int) *machinery.Worker } func testAll(server Server, t *testing.T) { @@ -340,6 +345,99 @@ func testDelay(server Server, t *testing.T) { } } +func testPubslishToLocal(oldServer Server, t *testing.T) { + tag := "health_check_tag" + config := oldServer.GetConfig() + config.DefaultQueue = tag + if config.AMQP != nil { + config.AMQP.BindingKey = tag + } + server, err := machinery.NewServer(config) + if err != nil { + t.Fatal(err) + } + + healthCheckTaskName := "health-check" + healthCheckCompleteChan := make(chan string, 1) + // RegisterTask health check task + err = server.RegisterTask(healthCheckTaskName, func(healthCheckUUID string) error { + select { + case healthCheckCompleteChan <- healthCheckUUID: // success and send uuid + return nil + case <-time.After(5 * time.Second): + return fmt.Errorf("send health check result error: %v", healthCheckUUID) + } + }) + assert.Nil(t, err) + + // start worker with concurrency 1 + worker := server.NewWorker(tag, 1) + worker.Queue = tag + go worker.Launch() + time.Sleep(1 * time.Second) // ensure worker start + + // check health: send message to local worker; wait `taskExecutionTimeout` until the task is completed + checkHealth := func(consumerTag string, taskExecutionTimeout time.Duration) error { + // clear channel + select { + case <-healthCheckCompleteChan: + default: + } + + broker := server.GetBroker() + healthCheckUUID, err := uuid.NewUUID() + if err != nil { + return err + } + if err := broker.PublishToLocal(consumerTag, &tasks.Signature{ + UUID: healthCheckUUID.String(), + Name: healthCheckTaskName, + Args: []tasks.Arg{ + {Type: "string", Value: healthCheckUUID.String()}, + }, + }, 5*time.Second); err != nil { + return err + } + + // wait for task execution success + select { + case successUUID := <-healthCheckCompleteChan: + if successUUID == healthCheckUUID.String() { + return nil + } + case <-time.After(taskExecutionTimeout): + } + return fmt.Errorf("health check execution fail: %v", healthCheckUUID.String()) + } + // trigger `checkHealth` + err = checkHealth(tag, 5*time.Second) + assert.Nil(t, err) + + // Simulation of worker being stuck + if err := server.RegisterTask("sleep-ten-seconds", func() error { + time.Sleep(10 * time.Second) + return nil + }); err != nil { + t.Fatal(err) + } + if _, err = server.SendTask(&tasks.Signature{ + Name: "sleep-ten-seconds", + }); err != nil { + t.Fatal(err) + } + time.Sleep(3 * time.Second) // ensure sleep-ten-seconds running + // checkHealth fail + err = checkHealth(tag, 5*time.Second) + assert.True(t, strings.HasPrefix(err.Error(), "health check execution fail: ")) + time.Sleep(6 * time.Second) // ensure queue is empty and last health check task executed + // checkHealth success + err = checkHealth(tag, 5*time.Second) + assert.Nil(t, err) + + // stop worker + worker.Quit() +} + func registerTestTasks(server Server) { tasks := map[string]interface{}{ diff --git a/v1/brokers/amqp/amqp.go b/v1/brokers/amqp/amqp.go index 5edcb81cf..1a6c25380 100644 --- a/v1/brokers/amqp/amqp.go +++ b/v1/brokers/amqp/amqp.go @@ -34,13 +34,15 @@ type Broker struct { common.AMQPConnector processingWG sync.WaitGroup // use wait group to make sure task processing completes on interrupt signal - connections map[string]*AMQPConnection - connectionsMutex sync.RWMutex + connections map[string]*AMQPConnection + connectionsMutex sync.RWMutex + localDeliveries map[string]chan amqp.Delivery + localDeliveriesMutex sync.RWMutex } // New creates new Broker instance func New(cnf *config.Config) iface.Broker { - return &Broker{Broker: common.NewBroker(cnf), AMQPConnector: common.AMQPConnector{}, connections: make(map[string]*AMQPConnection)} + return &Broker{Broker: common.NewBroker(cnf), AMQPConnector: common.AMQPConnector{}, connections: make(map[string]*AMQPConnection), localDeliveries: make(map[string]chan amqp.Delivery)} } // StartConsuming enters a loop and waits for incoming messages @@ -92,10 +94,14 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess if err != nil { return b.GetRetry(), fmt.Errorf("Queue consume error: %s", err) } + localDeliveries := make(chan amqp.Delivery, 1) + b.localDeliveriesMutex.Lock() + b.localDeliveries[consumerTag] = localDeliveries + b.localDeliveriesMutex.Unlock() log.INFO.Print("[*] Waiting for messages. To exit press CTRL+C") - if err := b.consume(deliveries, concurrency, taskProcessor, amqpCloseChan); err != nil { + if err := b.consume(deliveries, concurrency, taskProcessor, amqpCloseChan, localDeliveries); err != nil { return b.GetRetry(), err } @@ -251,7 +257,7 @@ func (b *Broker) Publish(ctx context.Context, signature *tasks.Signature) error // consume takes delivered messages from the channel and manages a worker pool // to process tasks concurrently -func (b *Broker) consume(deliveries <-chan amqp.Delivery, concurrency int, taskProcessor iface.TaskProcessor, amqpCloseChan <-chan *amqp.Error) error { +func (b *Broker) consume(deliveries <-chan amqp.Delivery, concurrency int, taskProcessor iface.TaskProcessor, amqpCloseChan <-chan *amqp.Error, healthCheckDeliveries chan amqp.Delivery) error { pool := make(chan struct{}, concurrency) // initialize worker pool with maxWorkers workers @@ -266,6 +272,30 @@ func (b *Broker) consume(deliveries <-chan amqp.Delivery, concurrency int, taskP // a worker, that is, it avoids a possible deadlock errorsChan := make(chan error, 1) + consumeDelivery := func(d amqp.Delivery) { + if concurrency > 0 { + // get worker from pool (blocks until one is available) + <-pool + } + + b.processingWG.Add(1) + + // Consume the task inside a gotourine so multiple tasks + // can be processed concurrently + go func() { + if err := b.consumeOne(d, taskProcessor, true); err != nil { + errorsChan <- err + } + + b.processingWG.Done() + + if concurrency > 0 { + // give worker back to pool + pool <- struct{}{} + } + }() + } + for { select { case amqpErr := <-amqpCloseChan: @@ -273,29 +303,11 @@ func (b *Broker) consume(deliveries <-chan amqp.Delivery, concurrency int, taskP case err := <-errorsChan: return err case d := <-deliveries: - if concurrency > 0 { - // get worker from pool (blocks until one is available) - <-pool - } - - b.processingWG.Add(1) - - // Consume the task inside a gotourine so multiple tasks - // can be processed concurrently - go func() { - if err := b.consumeOne(d, taskProcessor, true); err != nil { - errorsChan <- err - } - - b.processingWG.Done() - - if concurrency > 0 { - // give worker back to pool - pool <- struct{}{} - } - }() + consumeDelivery(d) case <-b.GetStopChan(): return nil + case d := <-healthCheckDeliveries: + consumeDelivery(d) } } } @@ -490,3 +502,22 @@ func (b *Broker) GetPendingTasks(queue string) ([]*tasks.Signature, error) { return dumper.Signatures, nil } + +func (b *Broker) PublishToLocal(consumerTag string, sig *tasks.Signature, blockTimeout time.Duration) error { + b.localDeliveriesMutex.RLock() + deliveries, ok := b.localDeliveries[consumerTag] + b.localDeliveriesMutex.RUnlock() + if !ok { + return fmt.Errorf("no such consumerTag: %v", consumerTag) + } + msg, err := json.Marshal(sig) + if err != nil { + return fmt.Errorf("JSON marshal error: %s", err) + } + select { + case deliveries <- amqp.Delivery{Body: msg}: + return nil + case <-time.After(blockTimeout): + return fmt.Errorf("health check: %v queue is full.", consumerTag) + } +} diff --git a/v1/brokers/amqp/amqp_concurrence_test.go b/v1/brokers/amqp/amqp_concurrence_test.go index 5522a8a38..31bc00d9f 100644 --- a/v1/brokers/amqp/amqp_concurrence_test.go +++ b/v1/brokers/amqp/amqp_concurrence_test.go @@ -26,10 +26,11 @@ func (_ doNothingProcessor) PreConsumeHandler() bool { func TestConsume(t *testing.T) { var ( - iBroker iface.Broker - deliveries = make(chan amqp.Delivery, 3) - closeChan chan *amqp.Error - processor doNothingProcessor + iBroker iface.Broker + deliveries = make(chan amqp.Delivery, 3) + localDeliveries = make(chan amqp.Delivery) + closeChan chan *amqp.Error + processor doNothingProcessor ) t.Run("with deliveries more than the number of concurrency", func(t *testing.T) { @@ -45,7 +46,7 @@ func TestConsume(t *testing.T) { }() go func() { - err := broker.consume(deliveries, 2, processor, closeChan) + err := broker.consume(deliveries, 2, processor, closeChan, localDeliveries) if err != nil { errChan <- err } diff --git a/v1/brokers/eager/eager.go b/v1/brokers/eager/eager.go index 3726b70d9..858bcba42 100644 --- a/v1/brokers/eager/eager.go +++ b/v1/brokers/eager/eager.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "time" "github.com/RichardKnop/machinery/v1/brokers/iface" "github.com/RichardKnop/machinery/v1/common" @@ -66,3 +67,7 @@ func (eagerBroker *Broker) Publish(ctx context.Context, task *tasks.Signature) e func (eagerBroker *Broker) AssignWorker(w iface.TaskProcessor) { eagerBroker.worker = w } + +func (b *Broker) PublishToLocal(consumerTag string, sig *tasks.Signature, blockTimeout time.Duration) error { + return b.Publish(context.Background(), sig) +} diff --git a/v1/brokers/gcppubsub/gcp_pubsub.go b/v1/brokers/gcppubsub/gcp_pubsub.go index 5e374a9b3..4f65586b0 100644 --- a/v1/brokers/gcppubsub/gcp_pubsub.go +++ b/v1/brokers/gcppubsub/gcp_pubsub.go @@ -194,3 +194,8 @@ func (b *Broker) consumeOne(delivery *pubsub.Message, taskProcessor iface.TaskPr // Call Ack() after successfully consuming and processing the message delivery.Ack() } + +// not supported +func (b *Broker) PublishToLocal(consumerTag string, sig *tasks.Signature, blockTimeout time.Duration) error { + return fmt.Errorf("gcp pubsub not support PublishToLocal") +} diff --git a/v1/brokers/iface/interfaces.go b/v1/brokers/iface/interfaces.go index c15e8f720..f77f9cfa5 100644 --- a/v1/brokers/iface/interfaces.go +++ b/v1/brokers/iface/interfaces.go @@ -2,6 +2,7 @@ package iface import ( "context" + "time" "github.com/RichardKnop/machinery/v1/config" "github.com/RichardKnop/machinery/v1/tasks" @@ -18,6 +19,8 @@ type Broker interface { GetPendingTasks(queue string) ([]*tasks.Signature, error) GetDelayedTasks() ([]*tasks.Signature, error) AdjustRoutingKey(s *tasks.Signature) + // send messages to local workers + PublishToLocal(consumerTag string, sig *tasks.Signature, t time.Duration) error } // TaskProcessor - can process a delivered task diff --git a/v1/brokers/redis/goredis.go b/v1/brokers/redis/goredis.go index feed029d2..d384ccb73 100644 --- a/v1/brokers/redis/goredis.go +++ b/v1/brokers/redis/goredis.go @@ -34,11 +34,13 @@ type BrokerGR struct { redsync *redsync.Redsync redisOnce sync.Once redisDelayedTasksKey string + deliveriesMap map[string]chan []byte + deliveriesMapMutex sync.RWMutex } // NewGR creates new Broker instance func NewGR(cnf *config.Config, addrs []string, db int) iface.Broker { - b := &BrokerGR{Broker: common.NewBroker(cnf)} + b := &BrokerGR{Broker: common.NewBroker(cnf), deliveriesMap: make(map[string]chan []byte)} var password string parts := strings.Split(addrs[0], "@") @@ -95,6 +97,9 @@ func (b *BrokerGR) StartConsuming(consumerTag string, concurrency int, taskProce // Channel to which we will push tasks ready for processing by worker deliveries := make(chan []byte, concurrency) pool := make(chan struct{}, concurrency) + b.deliveriesMapMutex.Lock() + b.deliveriesMap[consumerTag] = deliveries + b.deliveriesMapMutex.Unlock() // initialize worker pool with maxWorkers workers for i := 0; i < concurrency; i++ { @@ -420,3 +425,22 @@ func getQueueGR(config *config.Config, taskProcessor iface.TaskProcessor) string } return customQueue } + +func (b *BrokerGR) PublishToLocal(consumerTag string, sig *tasks.Signature, blockTimeout time.Duration) error { + b.deliveriesMapMutex.RLock() + deliveries, ok := b.deliveriesMap[consumerTag] + b.deliveriesMapMutex.RUnlock() + if !ok { + return fmt.Errorf("no such consumerTag: %v", consumerTag) + } + msg, err := json.Marshal(sig) + if err != nil { + return fmt.Errorf("JSON marshal error: %s", err) + } + select { + case deliveries <- msg: + return nil + case <-time.After(blockTimeout): + return fmt.Errorf("health check: %v queue is full.", consumerTag) + } +} diff --git a/v1/brokers/redis/redis.go b/v1/brokers/redis/redis.go index 1683f13e6..552c4c57f 100644 --- a/v1/brokers/redis/redis.go +++ b/v1/brokers/redis/redis.go @@ -40,11 +40,13 @@ type Broker struct { redsync *redsync.Redsync redisOnce sync.Once redisDelayedTasksKey string + deliveriesMap map[string]chan []byte + deliveriesMapMutex sync.RWMutex } // New creates new Broker instance func New(cnf *config.Config, host, password, socketPath string, db int) iface.Broker { - b := &Broker{Broker: common.NewBroker(cnf)} + b := &Broker{Broker: common.NewBroker(cnf), deliveriesMap: make(map[string]chan []byte)} b.host = host b.db = db b.password = password @@ -91,6 +93,9 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess // Channel to which we will push tasks ready for processing by worker deliveries := make(chan []byte, concurrency) pool := make(chan struct{}, concurrency) + b.deliveriesMapMutex.Lock() + b.deliveriesMap[consumerTag] = deliveries + b.deliveriesMapMutex.Unlock() // initialize worker pool with maxWorkers workers for i := 0; i < concurrency; i++ { @@ -485,3 +490,22 @@ func (b *Broker) requeueMessage(delivery []byte, taskProcessor iface.TaskProcess defer conn.Close() conn.Do("RPUSH", getQueue(b.GetConfig(), taskProcessor), delivery) } + +func (b *Broker) PublishToLocal(consumerTag string, sig *tasks.Signature, blockTimeout time.Duration) error { + b.deliveriesMapMutex.RLock() + deliveries, ok := b.deliveriesMap[consumerTag] + b.deliveriesMapMutex.RUnlock() + if !ok { + return fmt.Errorf("no such consumerTag: %v", consumerTag) + } + msg, err := json.Marshal(sig) + if err != nil { + return fmt.Errorf("JSON marshal error: %s", err) + } + select { + case deliveries <- msg: + return nil + case <-time.After(blockTimeout): + return fmt.Errorf("health check: %v queue is full.", consumerTag) + } +} diff --git a/v1/brokers/sqs/sqs.go b/v1/brokers/sqs/sqs.go index f811e16e5..dcb7a3a1c 100644 --- a/v1/brokers/sqs/sqs.go +++ b/v1/brokers/sqs/sqs.go @@ -30,17 +30,19 @@ const ( // There are examples on: https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/sqs-example-create-queue.html type Broker struct { common.Broker - processingWG sync.WaitGroup // use wait group to make sure task processing completes on interrupt signal - receivingWG sync.WaitGroup - stopReceivingChan chan int - sess *session.Session - service sqsiface.SQSAPI - queueUrl *string + processingWG sync.WaitGroup // use wait group to make sure task processing completes on interrupt signal + receivingWG sync.WaitGroup + stopReceivingChan chan int + sess *session.Session + service sqsiface.SQSAPI + queueUrl *string + deliveriesMap map[string]chan *awssqs.ReceiveMessageOutput + deliveriesMapMutex sync.RWMutex } // New creates new Broker instance func New(cnf *config.Config) iface.Broker { - b := &Broker{Broker: common.NewBroker(cnf)} + b := &Broker{Broker: common.NewBroker(cnf), deliveriesMap: make(map[string]chan *awssqs.ReceiveMessageOutput)} if cnf.SQS != nil && cnf.SQS.Client != nil { // Use provided *SQS client b.service = cnf.SQS.Client @@ -66,6 +68,9 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess deliveries := make(chan *awssqs.ReceiveMessageOutput, concurrency) pool := make(chan struct{}, concurrency) + b.deliveriesMapMutex.Lock() + b.deliveriesMap[consumerTag] = deliveries + b.deliveriesMapMutex.Unlock() // initialize worker pool with maxWorkers workers for i := 0; i < concurrency; i++ { @@ -366,3 +371,24 @@ func (b *Broker) getQueueURL(taskProcessor iface.TaskProcessor) *string { return aws.String(b.GetConfig().Broker + "/" + queueName) } + +func (b *Broker) PublishToLocal(consumerTag string, sig *tasks.Signature, blockTimeout time.Duration) error { + b.deliveriesMapMutex.RLock() + deliveries, ok := b.deliveriesMap[consumerTag] + b.deliveriesMapMutex.RUnlock() + if !ok { + return fmt.Errorf("no such consumerTag: %v", consumerTag) + } + msg, err := json.Marshal(sig) + if err != nil { + return fmt.Errorf("JSON marshal error: %s", err) + } + sqsMsg := awssqs.Message{} + sqsMsg.SetBody(string(msg)) + select { + case deliveries <- &awssqs.ReceiveMessageOutput{Messages: []*awssqs.Message{&sqsMsg}}: + return nil + case <-time.After(blockTimeout): + return fmt.Errorf("health check: %v queue is full.", consumerTag) + } +} diff --git a/v1/server_test.go b/v1/server_test.go index cce786df3..f48515891 100644 --- a/v1/server_test.go +++ b/v1/server_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" - + "github.com/RichardKnop/machinery/v1" "github.com/RichardKnop/machinery/v1/config" ) @@ -37,7 +37,7 @@ func TestRegisterTaskInRaceCondition(t *testing.T) { t.Parallel() server := getTestServer(t) - for i:=0; i<10; i++ { + for i := 0; i < 10; i++ { go func() { err := server.RegisterTask("test_task", func() error { return nil }) assert.NoError(t, err) diff --git a/v1/worker.go b/v1/worker.go index ef6b10bd8..b456b6c54 100644 --- a/v1/worker.go +++ b/v1/worker.go @@ -11,7 +11,7 @@ import ( "time" "github.com/opentracing/opentracing-go" - + "github.com/RichardKnop/machinery/v1/backends/amqp" "github.com/RichardKnop/machinery/v1/brokers/errs" "github.com/RichardKnop/machinery/v1/log" diff --git a/v1/worker_test.go b/v1/worker_test.go index 1a02084ba..bbdee5b3e 100644 --- a/v1/worker_test.go +++ b/v1/worker_test.go @@ -18,7 +18,7 @@ func TestRedactURL(t *testing.T) { func TestPreConsumeHandler(t *testing.T) { t.Parallel() - + worker := &machinery.Worker{} worker.SetPreConsumeHandler(SamplePreConsumeHandler)