diff --git a/publisher/kinesis/kinesis.go b/publisher/kinesis/kinesis.go index 7ff1cbfc..dc561899 100644 --- a/publisher/kinesis/kinesis.go +++ b/publisher/kinesis/kinesis.go @@ -20,8 +20,14 @@ import ( var globalCtx = context.Background() +type KinesisClient interface { + PutRecord(context.Context, *kinesis.PutRecordInput, ...func(*kinesis.Options)) (*kinesis.PutRecordOutput, error) + DescribeStreamSummary(context.Context, *kinesis.DescribeStreamSummaryInput, ...func(*kinesis.Options)) (*kinesis.DescribeStreamSummaryOutput, error) + CreateStream(context.Context, *kinesis.CreateStreamInput, ...func(*kinesis.Options)) (*kinesis.CreateStreamOutput, error) +} + type Publisher struct { - client *kinesis.Client + client KinesisClient streamLock sync.RWMutex streams map[string]bool diff --git a/publisher/kinesis/kinesis_integration_test.go b/publisher/kinesis/kinesis_integration_test.go new file mode 100644 index 00000000..2a9b98db --- /dev/null +++ b/publisher/kinesis/kinesis_integration_test.go @@ -0,0 +1,293 @@ +package kinesis_test + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + kinesis_sdk "github.com/aws/aws-sdk-go-v2/service/kinesis" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" + "github.com/raystack/raccoon/logger" + pb "github.com/raystack/raccoon/proto" + "github.com/raystack/raccoon/publisher/kinesis" + "github.com/stretchr/testify/require" +) + +const ( + envLocalstackHost = "LOCALSTACK_HOST" +) + +type localstackProvider struct{} + +func (p *localstackProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: "test", + SecretAccessKey: "test", + }, nil +} + +func withLocalStack(host string) func(o *kinesis_sdk.Options) { + return func(o *kinesis_sdk.Options) { + o.BaseEndpoint = aws.String(host) + o.Credentials = &localstackProvider{} + } +} + +var ( + testEvent = &pb.Event{ + EventBytes: []byte("EVENT"), + Type: "click", + } +) + +func createStream(client *kinesis_sdk.Client, name string) (string, error) { + _, err := client.CreateStream( + context.Background(), + &kinesis_sdk.CreateStreamInput{ + StreamName: aws.String(name), + StreamModeDetails: &types.StreamModeDetails{ + StreamMode: types.StreamModeOnDemand, + }, + ShardCount: aws.Int32(1), + }, + ) + if err != nil { + return "", err + } + retries := 5 + for range retries { + stream, err := client.DescribeStreamSummary( + context.Background(), + &kinesis_sdk.DescribeStreamSummaryInput{ + StreamName: aws.String(name), + }, + ) + if err != nil { + return "", err + } + if stream.StreamDescriptionSummary.StreamStatus == types.StreamStatusActive { + return *stream.StreamDescriptionSummary.StreamARN, nil + } + time.Sleep(time.Second / 2) + } + return "", fmt.Errorf("timed out waiting for stream to get ready") +} + +func deleteStream(client *kinesis_sdk.Client, name string) error { + _, err := client.DeleteStream(context.Background(), &kinesis_sdk.DeleteStreamInput{ + StreamName: aws.String(name), + }) + if err != nil { + return err + } + + var errNotFound *types.ResourceNotFoundException + for !errors.As(err, &errNotFound) { + _, err = client.DescribeStreamSummary( + context.Background(), + &kinesis_sdk.DescribeStreamSummaryInput{ + StreamName: aws.String(name), + }, + ) + time.Sleep(time.Second / 2) + } + + return nil +} + +func getStreamMode(client *kinesis_sdk.Client, name string) (types.StreamMode, error) { + stream, err := client.DescribeStreamSummary( + context.Background(), + &kinesis_sdk.DescribeStreamSummaryInput{ + StreamName: aws.String(name), + }, + ) + if err != nil { + return "", err + } + return stream.StreamDescriptionSummary.StreamModeDetails.StreamMode, nil +} + +func readStream(client *kinesis_sdk.Client, arn string) ([][]byte, error) { + stream, err := client.DescribeStream( + context.Background(), + &kinesis_sdk.DescribeStreamInput{ + StreamARN: aws.String(arn), + }, + ) + if err != nil { + return nil, err + } + if len(stream.StreamDescription.Shards) == 0 { + return nil, fmt.Errorf("stream %q has no shards", arn) + } + iter, err := client.GetShardIterator( + context.Background(), + &kinesis_sdk.GetShardIteratorInput{ + ShardId: stream.StreamDescription.Shards[0].ShardId, + StreamARN: aws.String(arn), + ShardIteratorType: types.ShardIteratorTypeTrimHorizon, + }, + ) + if err != nil { + return nil, err + } + res, err := client.GetRecords( + context.Background(), + &kinesis_sdk.GetRecordsInput{ + StreamARN: aws.String(arn), + ShardIterator: iter.ShardIterator, + }, + ) + if err != nil { + return nil, err + } + if len(res.Records) == 0 { + return nil, fmt.Errorf("got empty response") + } + rv := [][]byte{} + for _, record := range res.Records { + rv = append(rv, record.Data) + } + return rv, nil +} + +func TestKinesisProducer(t *testing.T) { + localstackHost := os.Getenv(envLocalstackHost) + if strings.TrimSpace(localstackHost) == "" { + t.Errorf("cannot run tests because %s env variable is not set", envLocalstackHost) + return + } + cfg, err := config.LoadDefaultConfig(context.Background()) + require.NoError(t, err, "error loading aws config") + + client := kinesis_sdk.NewFromConfig(cfg, withLocalStack(localstackHost)) + + t.Run("should return an error if stream doesn't exist", func(t *testing.T) { + pub, err := kinesis.New(client) + require.NoError(t, err) + err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") + require.Error(t, err) + }) + + t.Run("should return an error if an invalid stream mode is specified", func(t *testing.T) { + _, err := kinesis.New( + client, + kinesis.WithStreamMode("INVALID"), + ) + require.Error(t, err) + }) + + t.Run("should publish message to kinesis", func(t *testing.T) { + streamARN, err := createStream(client, testEvent.Type) + require.NoError(t, err) + defer deleteStream(client, testEvent.Type) + + pub, err := kinesis.New(client) + require.NoError(t, err) + pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") + require.NoError(t, err) + events, err := readStream(client, streamARN) + require.NoError(t, err) + require.Len(t, events, 1) + require.Equal(t, events[0], testEvent.EventBytes) + }) + t.Run("stream auto creation", func(t *testing.T) { + t.Run("should create the stream if it doesn't exist and autocreate is set to true", func(t *testing.T) { + pub, err := kinesis.New(client, kinesis.WithStreamAutocreate(true)) + require.NoError(t, err) + + err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") + require.NoError(t, err) + deleteStream(client, testEvent.Type) + }) + t.Run("should create the stream with mode = ON_DEMAND (default)", func(t *testing.T) { + pub, err := kinesis.New(client, kinesis.WithStreamAutocreate(true)) + require.NoError(t, err) + err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") + require.NoError(t, err) + defer deleteStream(client, testEvent.Type) + + mode, err := getStreamMode(client, testEvent.Type) + require.NoError(t, err) + require.Equal(t, mode, types.StreamModeOnDemand) + }) + t.Run("should create the stream with mode = PROVISIONED", func(t *testing.T) { + pub, err := kinesis.New( + client, + kinesis.WithStreamAutocreate(true), + kinesis.WithStreamMode(types.StreamModeProvisioned), + ) + require.NoError(t, err) + err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") + require.NoError(t, err) + defer deleteStream(client, testEvent.Type) + + mode, err := getStreamMode(client, testEvent.Type) + require.NoError(t, err) + require.Equal(t, mode, types.StreamModeProvisioned) + }) + t.Run("should create stream with specified number of shards", func(t *testing.T) { + shards := 5 + pub, err := kinesis.New( + client, + kinesis.WithStreamAutocreate(true), + kinesis.WithShards(uint32(shards)), + ) + require.NoError(t, err) + + err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") + require.NoError(t, err) + defer deleteStream(client, testEvent.Type) + + stream, err := client.DescribeStream( + context.Background(), + &kinesis_sdk.DescribeStreamInput{ + StreamName: aws.String(testEvent.Type), + }, + ) + require.NoError(t, err) + require.Equal(t, shards, len(stream.StreamDescription.Shards)) + }) + }) + + t.Run("should publish message according to the stream pattern", func(t *testing.T) { + streamPattern := "pre-%s-post" + destinationStream := "pre-click-post" + _, err := createStream(client, destinationStream) + require.NoError(t, err) + defer deleteStream(client, destinationStream) + pub, err := kinesis.New( + client, + kinesis.WithStreamPattern(streamPattern), + ) + require.NoError(t, err) + err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") + require.NoError(t, err) + }) + t.Run("should publish messages to static stream names", func(t *testing.T) { + destinationStream := "static" + _, err := createStream(client, destinationStream) + require.NoError(t, err) + defer deleteStream(client, destinationStream) + pub, err := kinesis.New( + client, + kinesis.WithStreamPattern(destinationStream), + ) + require.NoError(t, err) + err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") + require.NoError(t, err) + }) +} + +func TestMain(m *testing.M) { + logger.SetOutput(io.Discard) + os.Exit(m.Run()) +} diff --git a/publisher/kinesis/kinesis_test.go b/publisher/kinesis/kinesis_test.go index 2a9b98db..a6649dca 100644 --- a/publisher/kinesis/kinesis_test.go +++ b/publisher/kinesis/kinesis_test.go @@ -1,293 +1,48 @@ -package kinesis_test +package kinesis import ( - "context" - "errors" "fmt" - "io" - "os" - "strings" "testing" - "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - kinesis_sdk "github.com/aws/aws-sdk-go-v2/service/kinesis" - "github.com/aws/aws-sdk-go-v2/service/kinesis/types" - "github.com/raystack/raccoon/logger" + "github.com/aws/aws-sdk-go-v2/service/kinesis" pb "github.com/raystack/raccoon/proto" - "github.com/raystack/raccoon/publisher/kinesis" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) -const ( - envLocalstackHost = "LOCALSTACK_HOST" -) - -type localstackProvider struct{} - -func (p *localstackProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { - return aws.Credentials{ - AccessKeyID: "test", - SecretAccessKey: "test", - }, nil -} - -func withLocalStack(host string) func(o *kinesis_sdk.Options) { - return func(o *kinesis_sdk.Options) { - o.BaseEndpoint = aws.String(host) - o.Credentials = &localstackProvider{} - } -} - -var ( - testEvent = &pb.Event{ - EventBytes: []byte("EVENT"), - Type: "click", - } -) - -func createStream(client *kinesis_sdk.Client, name string) (string, error) { - _, err := client.CreateStream( - context.Background(), - &kinesis_sdk.CreateStreamInput{ - StreamName: aws.String(name), - StreamModeDetails: &types.StreamModeDetails{ - StreamMode: types.StreamModeOnDemand, +func TestKinesisProducer_UnitTest(t *testing.T) { + t.Run("should return an error if stream creation fails", func(t *testing.T) { + events := []*pb.Event{ + { + Type: "unknown", }, - ShardCount: aws.Int32(1), - }, - ) - if err != nil { - return "", err - } - retries := 5 - for range retries { - stream, err := client.DescribeStreamSummary( - context.Background(), - &kinesis_sdk.DescribeStreamSummaryInput{ - StreamName: aws.String(name), - }, - ) - if err != nil { - return "", err } - if stream.StreamDescriptionSummary.StreamStatus == types.StreamStatusActive { - return *stream.StreamDescriptionSummary.StreamARN, nil - } - time.Sleep(time.Second / 2) - } - return "", fmt.Errorf("timed out waiting for stream to get ready") -} - -func deleteStream(client *kinesis_sdk.Client, name string) error { - _, err := client.DeleteStream(context.Background(), &kinesis_sdk.DeleteStreamInput{ - StreamName: aws.String(name), - }) - if err != nil { - return err - } + client := &mockKinesisClient{} - var errNotFound *types.ResourceNotFoundException - for !errors.As(err, &errNotFound) { - _, err = client.DescribeStreamSummary( - context.Background(), - &kinesis_sdk.DescribeStreamSummaryInput{ - StreamName: aws.String(name), + client.On( + "DescribeStreamSummary", + mock.Anything, + &kinesis.DescribeStreamSummaryInput{ + StreamName: aws.String("unknown"), }, + mock.Anything, + ).Return( + &kinesis.DescribeStreamSummaryOutput{}, + fmt.Errorf("simulated error"), + ).Once() + + p, err := New( + nil, // we will override it later + WithStreamAutocreate(true), ) - time.Sleep(time.Second / 2) - } - - return nil -} - -func getStreamMode(client *kinesis_sdk.Client, name string) (types.StreamMode, error) { - stream, err := client.DescribeStreamSummary( - context.Background(), - &kinesis_sdk.DescribeStreamSummaryInput{ - StreamName: aws.String(name), - }, - ) - if err != nil { - return "", err - } - return stream.StreamDescriptionSummary.StreamModeDetails.StreamMode, nil -} - -func readStream(client *kinesis_sdk.Client, arn string) ([][]byte, error) { - stream, err := client.DescribeStream( - context.Background(), - &kinesis_sdk.DescribeStreamInput{ - StreamARN: aws.String(arn), - }, - ) - if err != nil { - return nil, err - } - if len(stream.StreamDescription.Shards) == 0 { - return nil, fmt.Errorf("stream %q has no shards", arn) - } - iter, err := client.GetShardIterator( - context.Background(), - &kinesis_sdk.GetShardIteratorInput{ - ShardId: stream.StreamDescription.Shards[0].ShardId, - StreamARN: aws.String(arn), - ShardIteratorType: types.ShardIteratorTypeTrimHorizon, - }, - ) - if err != nil { - return nil, err - } - res, err := client.GetRecords( - context.Background(), - &kinesis_sdk.GetRecordsInput{ - StreamARN: aws.String(arn), - ShardIterator: iter.ShardIterator, - }, - ) - if err != nil { - return nil, err - } - if len(res.Records) == 0 { - return nil, fmt.Errorf("got empty response") - } - rv := [][]byte{} - for _, record := range res.Records { - rv = append(rv, record.Data) - } - return rv, nil -} - -func TestKinesisProducer(t *testing.T) { - localstackHost := os.Getenv(envLocalstackHost) - if strings.TrimSpace(localstackHost) == "" { - t.Errorf("cannot run tests because %s env variable is not set", envLocalstackHost) - return - } - cfg, err := config.LoadDefaultConfig(context.Background()) - require.NoError(t, err, "error loading aws config") - - client := kinesis_sdk.NewFromConfig(cfg, withLocalStack(localstackHost)) - - t.Run("should return an error if stream doesn't exist", func(t *testing.T) { - pub, err := kinesis.New(client) - require.NoError(t, err) - err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") - require.Error(t, err) - }) - - t.Run("should return an error if an invalid stream mode is specified", func(t *testing.T) { - _, err := kinesis.New( - client, - kinesis.WithStreamMode("INVALID"), - ) - require.Error(t, err) - }) - - t.Run("should publish message to kinesis", func(t *testing.T) { - streamARN, err := createStream(client, testEvent.Type) - require.NoError(t, err) - defer deleteStream(client, testEvent.Type) - - pub, err := kinesis.New(client) - require.NoError(t, err) - pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") - require.NoError(t, err) - events, err := readStream(client, streamARN) - require.NoError(t, err) - require.Len(t, events, 1) - require.Equal(t, events[0], testEvent.EventBytes) - }) - t.Run("stream auto creation", func(t *testing.T) { - t.Run("should create the stream if it doesn't exist and autocreate is set to true", func(t *testing.T) { - pub, err := kinesis.New(client, kinesis.WithStreamAutocreate(true)) - require.NoError(t, err) - - err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") - require.NoError(t, err) - deleteStream(client, testEvent.Type) - }) - t.Run("should create the stream with mode = ON_DEMAND (default)", func(t *testing.T) { - pub, err := kinesis.New(client, kinesis.WithStreamAutocreate(true)) - require.NoError(t, err) - err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") - require.NoError(t, err) - defer deleteStream(client, testEvent.Type) - - mode, err := getStreamMode(client, testEvent.Type) - require.NoError(t, err) - require.Equal(t, mode, types.StreamModeOnDemand) - }) - t.Run("should create the stream with mode = PROVISIONED", func(t *testing.T) { - pub, err := kinesis.New( - client, - kinesis.WithStreamAutocreate(true), - kinesis.WithStreamMode(types.StreamModeProvisioned), - ) - require.NoError(t, err) - err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") - require.NoError(t, err) - defer deleteStream(client, testEvent.Type) - - mode, err := getStreamMode(client, testEvent.Type) - require.NoError(t, err) - require.Equal(t, mode, types.StreamModeProvisioned) - }) - t.Run("should create stream with specified number of shards", func(t *testing.T) { - shards := 5 - pub, err := kinesis.New( - client, - kinesis.WithStreamAutocreate(true), - kinesis.WithShards(uint32(shards)), - ) - require.NoError(t, err) - - err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") - require.NoError(t, err) - defer deleteStream(client, testEvent.Type) + if err != nil { + t.Errorf("error constructing client: %v", err) + return + } + p.client = client - stream, err := client.DescribeStream( - context.Background(), - &kinesis_sdk.DescribeStreamInput{ - StreamName: aws.String(testEvent.Type), - }, - ) - require.NoError(t, err) - require.Equal(t, shards, len(stream.StreamDescription.Shards)) - }) + err = p.ProduceBulk(events, "") + assert.NotNil(t, err) }) - - t.Run("should publish message according to the stream pattern", func(t *testing.T) { - streamPattern := "pre-%s-post" - destinationStream := "pre-click-post" - _, err := createStream(client, destinationStream) - require.NoError(t, err) - defer deleteStream(client, destinationStream) - pub, err := kinesis.New( - client, - kinesis.WithStreamPattern(streamPattern), - ) - require.NoError(t, err) - err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") - require.NoError(t, err) - }) - t.Run("should publish messages to static stream names", func(t *testing.T) { - destinationStream := "static" - _, err := createStream(client, destinationStream) - require.NoError(t, err) - defer deleteStream(client, destinationStream) - pub, err := kinesis.New( - client, - kinesis.WithStreamPattern(destinationStream), - ) - require.NoError(t, err) - err = pub.ProduceBulk([]*pb.Event{testEvent}, "conn_group") - require.NoError(t, err) - }) -} - -func TestMain(m *testing.M) { - logger.SetOutput(io.Discard) - os.Exit(m.Run()) } diff --git a/publisher/kinesis/mock.go b/publisher/kinesis/mock.go new file mode 100644 index 00000000..e1b6951f --- /dev/null +++ b/publisher/kinesis/mock.go @@ -0,0 +1,27 @@ +package kinesis + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/kinesis" + "github.com/stretchr/testify/mock" +) + +type mockKinesisClient struct { + mock.Mock +} + +func (cli *mockKinesisClient) PutRecord(ctx context.Context, in *kinesis.PutRecordInput, opts ...func(*kinesis.Options)) (*kinesis.PutRecordOutput, error) { + args := cli.Called(ctx, in, opts) + return args.Get(0).(*kinesis.PutRecordOutput), args.Error(1) +} + +func (cli *mockKinesisClient) DescribeStreamSummary(ctx context.Context, in *kinesis.DescribeStreamSummaryInput, opts ...func(*kinesis.Options)) (*kinesis.DescribeStreamSummaryOutput, error) { + args := cli.Called(ctx, in, opts) + return args.Get(0).(*kinesis.DescribeStreamSummaryOutput), args.Error(1) +} + +func (cli *mockKinesisClient) CreateStream(ctx context.Context, in *kinesis.CreateStreamInput, opts ...func(*kinesis.Options)) (*kinesis.CreateStreamOutput, error) { + args := cli.Called(ctx, in, opts) + return args.Get(0).(*kinesis.CreateStreamOutput), args.Error(1) +}