Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AWS SQS Visibility Heartbeat Feature #830

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions v2/brokers/sqs/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,13 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor
return errors.New("received empty message, the delivery is " + delivery.GoString())
}

if b.GetConfig().SQS.VisibilityHeartBeat {
notify := make(chan struct{})
defer close(notify)

b.visibilityHeartbeat(delivery, notify)
}

sig := new(tasks.Signature)
decoder := json.NewDecoder(strings.NewReader(*delivery.Messages[0].Body))
decoder.UseNumber()
Expand All @@ -219,15 +226,17 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor
// and leave the message in the queue
if !b.IsTaskRegistered(sig.Name) {
if sig.IgnoreWhenTaskNotRegistered {
b.deleteOne(delivery)
if err := b.deleteOne(delivery); err != nil {
log.ERROR.Printf("error when deleting the delivery. delivery is %v, Error=%s", delivery, err)
}
}
return fmt.Errorf("task %s is not registered", sig.Name)
}

err := taskProcessor.Process(sig)
if err != nil {
// stop task deletion in case we want to send messages to dlq in sqs
if err == errs.ErrStopTaskDeletion {
if errors.Is(err, errs.ErrStopTaskDeletion) {
return nil
}
return err
Expand Down Expand Up @@ -270,9 +279,8 @@ func (b *Broker) receiveMessage(qURL *string) (*awssqs.ReceiveMessageOutput, err
if b.GetConfig().SQS != nil {
waitTimeSeconds = b.GetConfig().SQS.WaitTimeSeconds
visibilityTimeout = b.GetConfig().SQS.VisibilityTimeout
} else {
waitTimeSeconds = 0
}

input := &awssqs.ReceiveMessageInput{
AttributeNames: []*string{
aws.String(awssqs.MessageSystemAttributeNameSentTimestamp),
Expand Down Expand Up @@ -350,6 +358,40 @@ func (b *Broker) continueReceivingMessages(qURL *string, deliveries chan *awssqs
return true, nil
}

// visibilityHeartbeat is a method that sends a heartbeat signal to AWS SQS to keep a message invisible to other consumers while being processed.
func (b *Broker) visibilityHeartbeat(delivery *awssqs.ReceiveMessageOutput, notify <-chan struct{}) {
if b.GetConfig().SQS.VisibilityTimeout == nil || *b.GetConfig().SQS.VisibilityTimeout == 0 {
return
}

ticker := time.NewTicker(time.Duration(*b.GetConfig().SQS.VisibilityTimeout) * 500 * time.Millisecond)

go func() {
for {
select {
case <-notify:
ticker.Stop()

return
case <-b.stopReceivingChan:
ticker.Stop()

return
case <-ticker.C:
// Extend the delivery visibility timeout
_, err := b.service.ChangeMessageVisibility(&awssqs.ChangeMessageVisibilityInput{
QueueUrl: b.defaultQueueURL(),
ReceiptHandle: delivery.Messages[0].ReceiptHandle,
VisibilityTimeout: aws.Int64(int64(*b.GetConfig().SQS.VisibilityTimeout)),
})
if err != nil {
log.ERROR.Printf("Error when changing delivery visibility: %v", err)
}
}
}
}()
}

// stopReceiving is a method sending a signal to stopReceivingChan
func (b *Broker) stopReceiving() {
// Stop the receiving goroutine
Expand Down
20 changes: 12 additions & 8 deletions v2/brokers/sqs/sqs_export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@ import (
"os"
"sync"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"

"github.com/RichardKnop/machinery/v2/brokers/iface"
"github.com/RichardKnop/machinery/v2/common"
"github.com/RichardKnop/machinery/v2/config"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
awssqs "github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
)

var (
Expand Down Expand Up @@ -107,17 +105,23 @@ func NewTestConfig() *config.Config {
DefaultQueue: "test_queue",
ResultBackend: fmt.Sprintf("redis://%v", redisURL),
Lock: fmt.Sprintf("redis://%v", redisURL),
SQS: &config.SQSConfig{
VisibilityTimeout: aws.Int(30),
},
}
}

func NewTestBroker() *Broker {
func NewTestBroker(cnf *config.Config) *Broker {

cnf := NewTestConfig()
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))

svc := new(FakeSQS)
var svc sqsiface.SQSAPI = new(FakeSQS)

if cnf.SQS.Client != nil {
svc = cnf.SQS.Client
}
return &Broker{
Broker: common.NewBroker(cnf),
sess: sess,
Expand Down
Loading