Skip to content

Commit

Permalink
[DynamoDB] Support ResultsExpiresIn (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
gow authored Apr 14, 2020
1 parent f259ace commit 107e5de
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ dynamodb:
```
If these tables are not found, an fatal error would be thrown.

If you wish to expire the records, you can configure the `TTL` field in AWS admin for these tables. The `TTL` field is set based on the `ResultsExpireIn` value in the Server's config. See https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/howitworks-ttl.html for more information.

### Custom Logger

You can define a custom logger by implementing the following interface:
Expand Down
27 changes: 27 additions & 0 deletions v1/backends/dynamodb/dynamodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func (b *Backend) InitGroup(groupUUID string, taskUUIDs []string) error {
GroupUUID: groupUUID,
TaskUUIDs: taskUUIDs,
CreatedAt: time.Now().UTC(),
TTL: b.getExpirationTime(),
}
av, err := dynamodbattribute.MarshalMap(meta)
if err != nil {
Expand Down Expand Up @@ -164,12 +165,14 @@ func (b *Backend) SetStateRetry(signature *tasks.Signature) error {
// SetStateSuccess ...
func (b *Backend) SetStateSuccess(signature *tasks.Signature, results []*tasks.TaskResult) error {
taskState := tasks.NewSuccessTaskState(signature, results)
taskState.TTL = b.getExpirationTime()
return b.setTaskState(taskState)
}

// SetStateFailure ...
func (b *Backend) SetStateFailure(signature *tasks.Signature, err string) error {
taskState := tasks.NewFailureTaskState(signature, err)
taskState.TTL = b.getExpirationTime()
return b.updateToFailureStateWithError(taskState)
}

Expand Down Expand Up @@ -367,6 +370,13 @@ func (b *Backend) setTaskState(taskState *tasks.TaskState) error {
}
exp += ", #C = :c"
}
if taskState.TTL > 0 {
expAttributeNames["#T"] = aws.String("TTL")
expAttributeValues[":t"] = &dynamodb.AttributeValue{
N: aws.String(fmt.Sprintf("%d", taskState.TTL)),
}
exp += ", #T = :t"
}
if taskState.Results != nil && len(taskState.Results) != 0 {
expAttributeNames["#R"] = aws.String("Results")
var results []*dynamodb.AttributeValue
Expand Down Expand Up @@ -447,6 +457,14 @@ func (b *Backend) updateToFailureStateWithError(taskState *tasks.TaskState) erro
UpdateExpression: aws.String("SET #S = :s, #E = :e"),
}

if taskState.TTL > 0 {
input.ExpressionAttributeNames["#T"] = aws.String("TTL")
input.ExpressionAttributeValues[":t"] = &dynamodb.AttributeValue{
N: aws.String(fmt.Sprintf("%d", taskState.TTL)),
}
input.UpdateExpression = aws.String(aws.StringValue(input.UpdateExpression) + ", #T = :t")
}

_, err := b.client.UpdateItem(input)

if err != nil {
Expand Down Expand Up @@ -511,3 +529,12 @@ func (b *Backend) tableExists(tableName string, tableNames []*string) bool {
}
return false
}

func (b *Backend) getExpirationTime() int64 {
expiresIn := b.GetConfig().ResultsExpireIn
if expiresIn == 0 {
// expire results after 1 hour by default
expiresIn = config.DefaultResultsExpireIn
}
return time.Now().Add(time.Second * time.Duration(expiresIn)).Unix()
}
17 changes: 15 additions & 2 deletions v1/backends/dynamodb/dynamodb_export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,19 @@ var (

type TestDynamoDBClient struct {
dynamodbiface.DynamoDBAPI
PutItemOverride func(*dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error)
UpdateItemOverride func(*dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error)
}

func (t *TestDynamoDBClient) PutItem(*dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
func (t *TestDynamoDBClient) ResetOverrides() {
t.PutItemOverride = nil
t.UpdateItemOverride = nil
}

func (t *TestDynamoDBClient) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
if t.PutItemOverride != nil {
return t.PutItemOverride(input)
}
return &dynamodb.PutItemOutput{}, nil
}

Expand Down Expand Up @@ -104,7 +114,10 @@ func (t *TestDynamoDBClient) DeleteItem(*dynamodb.DeleteItemInput) (*dynamodb.De
return &dynamodb.DeleteItemOutput{}, nil
}

func (t *TestDynamoDBClient) UpdateItem(*dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error) {
func (t *TestDynamoDBClient) UpdateItem(input *dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error) {
if t.UpdateItemOverride != nil {
return t.UpdateItemOverride(input)
}
return &dynamodb.UpdateItemOutput{}, nil
}

Expand Down
144 changes: 144 additions & 0 deletions v1/backends/dynamodb/dynamodb_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package dynamodb_test

import (
"strconv"
"testing"
"time"

"github.com/RichardKnop/machinery/v1/backends/dynamodb"
"github.com/RichardKnop/machinery/v1/log"
Expand All @@ -22,11 +24,39 @@ func TestInitGroup(t *testing.T) {
groupUUID := "testGroupUUID"
taskUUIDs := []string{"testTaskUUID1", "testTaskUUID2", "testTaskUUID3"}
log.INFO.Println(dynamodb.TestDynamoDBBackend.GetConfig())

err := dynamodb.TestDynamoDBBackend.InitGroup(groupUUID, taskUUIDs)
assert.Nil(t, err)

err = dynamodb.TestErrDynamoDBBackend.InitGroup(groupUUID, taskUUIDs)
assert.NotNil(t, err)

// assert proper TTL value is set in InitGroup()
dynamodb.TestDynamoDBBackend.GetConfig().ResultsExpireIn = 3 * 3600 // results should expire after 3 hours
client := dynamodb.TestDynamoDBBackend.GetClient().(*dynamodb.TestDynamoDBClient)
// Override DynamoDB PutItem() behavior
var isPutItemCalled bool
client.PutItemOverride = func(input *awsdynamodb.PutItemInput) (*awsdynamodb.PutItemOutput, error) {
isPutItemCalled = true
assert.NotNil(t, input)

actualTTLStr := *input.Item["TTL"].N
expectedTTLTime := time.Now().Add(3 * time.Hour)
assertTTLValue(t, expectedTTLTime, actualTTLStr)

return &awsdynamodb.PutItemOutput{}, nil
}
err = dynamodb.TestDynamoDBBackend.InitGroup(groupUUID, taskUUIDs)
assert.Nil(t, err)
assert.True(t, isPutItemCalled)
client.ResetOverrides()
}

func assertTTLValue(t *testing.T, expectedTTLTime time.Time, actualEncodedTTLValue string) {
actualTTLTimestamp, err := strconv.ParseInt(actualEncodedTTLValue, 10, 64)
assert.Nil(t, err)
actualTTLTime := time.Unix(actualTTLTimestamp, 0)
assert.WithinDuration(t, expectedTTLTime, actualTTLTime, time.Second)
}

func TestGroupCompleted(t *testing.T) {
Expand Down Expand Up @@ -258,6 +288,120 @@ func TestPrivateFuncSetTaskState(t *testing.T) {
assert.Nil(t, err)
}

// verifyUpdateInput is a helper function to verify valid dynamoDB update input.
func verifyUpdateInput(t *testing.T, input *awsdynamodb.UpdateItemInput, expectedTaskID string, expectedState string, expectedTTLTime time.Time) {
assert.NotNil(t, input)

// verify task ID
assert.Equal(t, expectedTaskID, *input.Key["TaskUUID"].S)

// verify task state
assert.Equal(t, expectedState, *input.ExpressionAttributeValues[":s"].S)

// Verify TTL
if !expectedTTLTime.IsZero() {
actualTTLStr := *input.ExpressionAttributeValues[":t"].N
assertTTLValue(t, expectedTTLTime, actualTTLStr)
}
}

func TestSetStateSuccess(t *testing.T) {
signature := &tasks.Signature{UUID: "testTaskUUID"}

// assert correct task ID, state and TTL value is set in SetStateSuccess()
dynamodb.TestDynamoDBBackend.GetConfig().ResultsExpireIn = 3 * 3600 // results should expire after 3 hours
client := dynamodb.TestDynamoDBBackend.GetClient().(*dynamodb.TestDynamoDBClient)
// Override DynamoDB UpdateItem() behavior
var isUpdateItemCalled bool
client.UpdateItemOverride = func(input *awsdynamodb.UpdateItemInput) (*awsdynamodb.UpdateItemOutput, error) {
isUpdateItemCalled = true
verifyUpdateInput(t, input, signature.UUID, tasks.StateSuccess, time.Now().Add(3*time.Hour))
return &awsdynamodb.UpdateItemOutput{}, nil
}

err := dynamodb.TestDynamoDBBackend.SetStateSuccess(signature, nil)
assert.Nil(t, err)
assert.True(t, isUpdateItemCalled)
client.ResetOverrides()
}

func TestSetStateFailure(t *testing.T) {
signature := &tasks.Signature{UUID: "testTaskUUID"}

// assert correct task ID, state and TTL value is set in SetStateFailure()
dynamodb.TestDynamoDBBackend.GetConfig().ResultsExpireIn = 2 * 3600 // results should expire after 2 hours
client := dynamodb.TestDynamoDBBackend.GetClient().(*dynamodb.TestDynamoDBClient)
// Override DynamoDB UpdateItem() behavior
var isUpdateItemCalled bool
client.UpdateItemOverride = func(input *awsdynamodb.UpdateItemInput) (*awsdynamodb.UpdateItemOutput, error) {
isUpdateItemCalled = true
verifyUpdateInput(t, input, signature.UUID, tasks.StateFailure, time.Now().Add(2*time.Hour))
return &awsdynamodb.UpdateItemOutput{}, nil
}

err := dynamodb.TestDynamoDBBackend.SetStateFailure(signature, "Some error occurred")
assert.Nil(t, err)
assert.True(t, isUpdateItemCalled)
client.ResetOverrides()
}

func TestSetStateReceived(t *testing.T) {
signature := &tasks.Signature{UUID: "testTaskUUID"}

// assert correct task ID, state and *no* TTL value is set in SetStateReceived()
dynamodb.TestDynamoDBBackend.GetConfig().ResultsExpireIn = 2 * 3600 // results should expire after 2 hours (ignored for this state)
client := dynamodb.TestDynamoDBBackend.GetClient().(*dynamodb.TestDynamoDBClient)
var isUpdateItemCalled bool
client.UpdateItemOverride = func(input *awsdynamodb.UpdateItemInput) (*awsdynamodb.UpdateItemOutput, error) {
isUpdateItemCalled = true
verifyUpdateInput(t, input, signature.UUID, tasks.StateReceived, time.Time{})
return &awsdynamodb.UpdateItemOutput{}, nil
}

err := dynamodb.TestDynamoDBBackend.SetStateReceived(signature)
assert.Nil(t, err)
assert.True(t, isUpdateItemCalled)
client.ResetOverrides()
}

func TestSetStateStarted(t *testing.T) {
signature := &tasks.Signature{UUID: "testTaskUUID"}

// assert correct task ID, state and *no* TTL value is set in SetStateStarted()
dynamodb.TestDynamoDBBackend.GetConfig().ResultsExpireIn = 2 * 3600 // results should expire after 2 hours (ignored for this state)
client := dynamodb.TestDynamoDBBackend.GetClient().(*dynamodb.TestDynamoDBClient)
var isUpdateItemCalled bool
client.UpdateItemOverride = func(input *awsdynamodb.UpdateItemInput) (*awsdynamodb.UpdateItemOutput, error) {
isUpdateItemCalled = true
verifyUpdateInput(t, input, signature.UUID, tasks.StateStarted, time.Time{})
return &awsdynamodb.UpdateItemOutput{}, nil
}

err := dynamodb.TestDynamoDBBackend.SetStateStarted(signature)
assert.Nil(t, err)
assert.True(t, isUpdateItemCalled)
client.ResetOverrides()
}

func TestSetStateRetry(t *testing.T) {
signature := &tasks.Signature{UUID: "testTaskUUID"}

// assert correct task ID, state and *no* TTL value is set in SetStateStarted()
dynamodb.TestDynamoDBBackend.GetConfig().ResultsExpireIn = 2 * 3600 // results should expire after 2 hours (ignored for this state)
client := dynamodb.TestDynamoDBBackend.GetClient().(*dynamodb.TestDynamoDBClient)
var isUpdateItemCalled bool
client.UpdateItemOverride = func(input *awsdynamodb.UpdateItemInput) (*awsdynamodb.UpdateItemOutput, error) {
isUpdateItemCalled = true
verifyUpdateInput(t, input, signature.UUID, tasks.StateRetry, time.Time{})
return &awsdynamodb.UpdateItemOutput{}, nil
}

err := dynamodb.TestDynamoDBBackend.SetStateRetry(signature)
assert.Nil(t, err)
assert.True(t, isUpdateItemCalled)
client.ResetOverrides()
}

func TestPrivateFuncGetStates(t *testing.T) {
task1 := map[string]*awsdynamodb.AttributeValue{
"Error": {
Expand Down
2 changes: 2 additions & 0 deletions v1/tasks/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type TaskState struct {
Results []*TaskResult `bson:"results"`
Error string `bson:"error"`
CreatedAt time.Time `bson:"created_at"`
TTL int64 `bson:"ttl,omitempty"`
}

// GroupMeta stores useful metadata about tasks within the same group
Expand All @@ -36,6 +37,7 @@ type GroupMeta struct {
ChordTriggered bool `bson:"chord_triggered"`
Lock bool `bson:"lock"`
CreatedAt time.Time `bson:"created_at"`
TTL int64 `bson:"ttl,omitempty"`
}

// NewPendingTaskState ...
Expand Down

0 comments on commit 107e5de

Please sign in to comment.