diff --git a/pkg/storage/unified/resource/broadcaster_test.go b/pkg/storage/unified/resource/broadcaster_test.go index 8eedfa01ce564..3ae056a8a772a 100644 --- a/pkg/storage/unified/resource/broadcaster_test.go +++ b/pkg/storage/unified/resource/broadcaster_test.go @@ -104,3 +104,41 @@ func TestCache(t *testing.T) { // slice should return all values require.Equal(t, []int{4, 5, 6, 7, 8, 9, 10, 11, 12, 13}, c.Slice()) } + +func TestBroadcaster(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan int) + input := []int{1, 2, 3} + go func() { + for _, v := range input { + ch <- v + } + }() + t.Cleanup(func() { + close(ch) + }) + + b, err := NewBroadcaster(ctx, func(out chan<- int) error { + go func() { + for v := range ch { + out <- v + } + }() + return nil + }) + require.NoError(t, err) + + sub, err := b.Subscribe(ctx) + require.NoError(t, err) + + for _, expected := range input { + v, ok := <-sub + require.True(t, ok) + require.Equal(t, expected, v) + } + + // cancel the context should close the stream + cancel() + _, ok := <-sub + require.False(t, ok) +} diff --git a/pkg/storage/unified/resource/server.go b/pkg/storage/unified/resource/server.go index 7853f9125dbed..148de06595a71 100644 --- a/pkg/storage/unified/resource/server.go +++ b/pkg/storage/unified/resource/server.go @@ -919,10 +919,7 @@ func (s *server) initWatcher() error { return err } go func() { - for { - // pipe all events - v := <-events - + for v := range events { if v == nil { s.log.Error("received nil event") continue diff --git a/pkg/storage/unified/sql/notifier_sql.go b/pkg/storage/unified/sql/notifier_sql.go index a4d0a008fccfa..9940f7c882f4a 100644 --- a/pkg/storage/unified/sql/notifier_sql.go +++ b/pkg/storage/unified/sql/notifier_sql.go @@ -120,6 +120,8 @@ func (p *pollingNotifier) poller(ctx context.Context, since groupResourceRV, str for { select { + case <-ctx.Done(): + return case <-p.done: return case <-t.C: diff --git a/pkg/storage/unified/sql/notifier_sql_test.go b/pkg/storage/unified/sql/notifier_sql_test.go index 693d63ea4851e..07fcfac3365cb 100644 --- a/pkg/storage/unified/sql/notifier_sql_test.go +++ b/pkg/storage/unified/sql/notifier_sql_test.go @@ -357,4 +357,41 @@ func TestPollingNotifier(t *testing.T) { t.Fatal("timeout waiting for events channel to close") } }) + + t.Run("stops polling when context is cancelled", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + + cfg := &pollingNotifierConfig{ + dialect: sqltemplate.SQLite, + pollingInterval: 10 * time.Millisecond, + watchBufferSize: 10, + log: log.NewNopLogger(), + tracer: noop.NewTracerProvider().Tracer("test"), + batchLock: &batchLock{}, + listLatestRVs: func(ctx context.Context) (groupResourceRV, error) { return nil, nil }, + historyPoll: func(ctx context.Context, grp string, res string, since int64) ([]*historyPollResponse, error) { + return nil, nil + }, + done: make(chan struct{}), + } + + notifier, err := newPollingNotifier(cfg) + require.NoError(t, err) + require.NotNil(t, notifier) + + events, err := notifier.notify(ctx) + require.NoError(t, err) + require.NotNil(t, events) + + cancel() + + select { + case _, ok := <-events: + require.False(t, ok, "events channel should be closed") + case <-time.After(time.Second): + t.Fatal("timeout waiting for events channel to close") + } + }) } diff --git a/pkg/storage/unified/testing/storage_backend.go b/pkg/storage/unified/testing/storage_backend.go index d34cf3c0c71bf..d19f9d2b05a66 100644 --- a/pkg/storage/unified/testing/storage_backend.go +++ b/pkg/storage/unified/testing/storage_backend.go @@ -51,7 +51,7 @@ func RunStorageBackendTest(t *testing.T, newBackend NewBackendFunc, opts *TestOp fn func(*testing.T, resource.StorageBackend) }{ {TestHappyPath, runTestIntegrationBackendHappyPath}, - {TestWatchWriteEvents, runTestIntegrationBackendWatchWriteEventsFromLastest}, + {TestWatchWriteEvents, runTestIntegrationBackendWatchWriteEvents}, {TestList, runTestIntegrationBackendList}, {TestBlobSupport, runTestIntegrationBlobSupport}, {TestGetResourceStats, runTestIntegrationBackendGetResourceStats}, @@ -272,7 +272,7 @@ func runTestIntegrationBackendGetResourceStats(t *testing.T, backend resource.St }) } -func runTestIntegrationBackendWatchWriteEventsFromLastest(t *testing.T, backend resource.StorageBackend) { +func runTestIntegrationBackendWatchWriteEvents(t *testing.T, backend resource.StorageBackend) { ctx := testutil.NewTestContext(t, time.Now().Add(5*time.Second)) // Create a few resources before initing the watch @@ -287,6 +287,12 @@ func runTestIntegrationBackendWatchWriteEventsFromLastest(t *testing.T, backend _, err = writeEvent(ctx, backend, "item2", resource.WatchEvent_ADDED) require.NoError(t, err) require.Equal(t, "item2", (<-stream).Key.Name) + + // Should close the stream + ctx.Cancel() + + _, ok := <-stream + require.False(t, ok) } func runTestIntegrationBackendList(t *testing.T, backend resource.StorageBackend) {