From 337a42a4995e5e45a5980eea749d8a179c0ff069 Mon Sep 17 00:00:00 2001 From: Blake Rouse Date: Thu, 27 Feb 2025 23:07:35 -0500 Subject: [PATCH] Fix panic on fetch provider initialization (#6958) --- internal/pkg/composable/context.go | 6 +- internal/pkg/composable/controller.go | 58 ++++++------- internal/pkg/composable/controller_test.go | 95 ++++++++++++++++++++++ internal/pkg/composable/dynamic.go | 6 +- internal/pkg/composable/registry.go | 19 +++-- 5 files changed, 139 insertions(+), 45 deletions(-) diff --git a/internal/pkg/composable/context.go b/internal/pkg/composable/context.go index f358e0bfe32..0eadeb9f5fb 100644 --- a/internal/pkg/composable/context.go +++ b/internal/pkg/composable/context.go @@ -17,7 +17,7 @@ import ( type ContextProviderBuilder func(log *logger.Logger, config *config.Config, managed bool) (corecomp.ContextProvider, error) // MustAddContextProvider adds a new ContextProviderBuilder and panics if it AddContextProvider returns an error. -func (r *providerRegistry) MustAddContextProvider(name string, builder ContextProviderBuilder) { +func (r *ProviderRegistry) MustAddContextProvider(name string, builder ContextProviderBuilder) { err := r.AddContextProvider(name, builder) if err != nil { panic(err) @@ -27,7 +27,7 @@ func (r *providerRegistry) MustAddContextProvider(name string, builder ContextPr // AddContextProvider adds a new ContextProviderBuilder // //nolint:dupl,goimports,nolintlint // false positive -func (r *providerRegistry) AddContextProvider(name string, builder ContextProviderBuilder) error { +func (r *ProviderRegistry) AddContextProvider(name string, builder ContextProviderBuilder) error { r.lock.Lock() defer r.lock.Unlock() @@ -55,7 +55,7 @@ func (r *providerRegistry) AddContextProvider(name string, builder ContextProvid } // GetContextProvider returns the context provider with the giving name, nil if it doesn't exist -func (r *providerRegistry) GetContextProvider(name string) (ContextProviderBuilder, bool) { +func (r *ProviderRegistry) GetContextProvider(name string) (ContextProviderBuilder, bool) { r.lock.RLock() defer r.lock.RUnlock() diff --git a/internal/pkg/composable/controller.go b/internal/pkg/composable/controller.go index 453a3d47fb7..61402d7f855 100644 --- a/internal/pkg/composable/controller.go +++ b/internal/pkg/composable/controller.go @@ -80,8 +80,13 @@ type controller struct { dynamicProviderStates map[string]*dynamicProviderState } -// New creates a new controller. +// New creates a new controller with the global set of providers. func New(log *logger.Logger, c *config.Config, managed bool) (Controller, error) { + return NewWithProviders(log, c, managed, Providers) +} + +// NewWithProviders creates a new controller with the given set of providers. +func NewWithProviders(log *logger.Logger, c *config.Config, managed bool, providers *ProviderRegistry) (Controller, error) { l := log.Named("composable") var providersCfg Config @@ -110,7 +115,7 @@ func New(log *logger.Logger, c *config.Config, managed bool) (Controller, error) // build all the context providers contextProviders := map[string]contextProvider{} - for name, builder := range Providers.contextProviders { + for name, builder := range providers.contextProviders { pCfg, ok := providersCfg.Providers[name] if (ok && !pCfg.Enabled()) || (!ok && !providersInitialDefault) { // explicitly disabled; skipping @@ -124,7 +129,7 @@ func New(log *logger.Logger, c *config.Config, managed bool) (Controller, error) // build all the dynamic providers dynamicProviders := map[string]dynamicProvider{} - for name, builder := range Providers.dynamicProviders { + for name, builder := range providers.dynamicProviders { pCfg, ok := providersCfg.Providers[name] if (ok && !pCfg.Enabled()) || (!ok && !providersInitialDefault) { // explicitly disabled; skipping @@ -187,45 +192,33 @@ func (c *controller) Run(ctx context.Context) error { wg.Wait() }() - // synchronize the fetch providers through a channel - var fetchProvidersLock sync.RWMutex - var fetchProviders mapstr.M - fetchCh := make(chan fetchProvider) - go func() { - for { - select { - case <-localCtx.Done(): - return - case msg := <-fetchCh: - fetchProvidersLock.Lock() - if msg.fetchProvider == nil { - _ = fetchProviders.Delete(msg.name) - } else { - _, _ = fetchProviders.Put(msg.name, msg.fetchProvider) - } - fetchProvidersLock.Unlock() - } - } - }() - - // send initial vars state - fetchProvidersLock.RLock() + // send initial vars state (empty fetch providers initially) + fetchProviders := mapstr.M{} err := c.sendVars(ctx, nil, fetchProviders) if err != nil { - fetchProvidersLock.RUnlock() // only error is context cancel, no need to add error message context return err } - fetchProvidersLock.RUnlock() // performs debounce of notifies; accumulates them into 100 millisecond chunks var observedResult chan []*transpiler.Vars + fetchCh := make(chan fetchProvider) for { DEBOUNCE: for { select { case <-ctx.Done(): return ctx.Err() + case msg := <-fetchCh: + if msg.fetchProvider == nil { + _ = fetchProviders.Delete(msg.name) + } else { + _, _ = fetchProviders.Put(msg.name, msg.fetchProvider) + } + t.Reset(100 * time.Millisecond) + c.logger.Debugf("Fetch providers state changed for composable inputs; debounce started") + drainChan(stateChangedChan) // state change trigger (no need for signal to be handled) + break DEBOUNCE case observed := <-c.observedCh: // observedResult holds the channel to send the latest observed results on // if nothing is changed then nil will be sent over the channel if the set of running @@ -235,7 +228,7 @@ func (c *controller) Run(ctx context.Context) error { if changed { t.Reset(100 * time.Millisecond) c.logger.Debugf("Observed state changed for composable inputs; debounce started") - drainChan(stateChangedChan) + drainChan(stateChangedChan) // state change trigger (no need for signal to be handled) break DEBOUNCE } else { // nothing changed send nil to alert the caller @@ -261,15 +254,12 @@ func (c *controller) Run(ctx context.Context) error { } // send the vars to the watcher or the observer caller - fetchProvidersLock.RLock() err := c.sendVars(ctx, observedResult, fetchProviders) observedResult = nil if err != nil { - fetchProvidersLock.RUnlock() // only error is context cancel, no need to add error message context return err } - fetchProvidersLock.RUnlock() } } @@ -550,6 +540,10 @@ func (c *controller) startDynamicProvider(ctx context.Context, wg *sync.WaitGrou } func (c *controller) generateVars(fetchContextProviders mapstr.M, defaultProvider string) []*transpiler.Vars { + // copy fetch providers map so they cannot change in the context + // of the currently processed variables + fetchContextProviders = fetchContextProviders.Clone() + // build the vars list of mappings vars := make([]*transpiler.Vars, 1) mapping, _ := transpiler.NewAST(map[string]any{}) diff --git a/internal/pkg/composable/controller_test.go b/internal/pkg/composable/controller_test.go index 5bedf879154..75df28069d4 100644 --- a/internal/pkg/composable/controller_test.go +++ b/internal/pkg/composable/controller_test.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "strings" "testing" "time" @@ -17,6 +18,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/transpiler" "github.com/elastic/elastic-agent/internal/pkg/composable" "github.com/elastic/elastic-agent/internal/pkg/config" + corecomp "github.com/elastic/elastic-agent/internal/pkg/core/composable" "github.com/elastic/elastic-agent/pkg/core/logger" _ "github.com/elastic/elastic-agent/internal/pkg/composable/providers/env" @@ -142,6 +144,81 @@ func TestController(t *testing.T) { assert.Len(t, vars3map, 0) // should be empty after empty Observe } +func TestControllerWithFetchProvider(t *testing.T) { + providers := composable.NewProviderRegistry() + providers.MustAddContextProvider("custom_fetch", func(_ *logger.Logger, _ *config.Config, _ bool) (corecomp.ContextProvider, error) { + // add a delay to ensure that even if it takes time to start the provider that it still gets placed + // as a fetch provider + <-time.After(1 * time.Second) + return &customFetchProvider{}, nil + }) + + cfg := config.New() + log, err := logger.New("", false) + require.NoError(t, err) + c, err := composable.NewWithProviders(log, cfg, false, providers) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + observed := false + setErr := make(chan error, 1) + go func() { + defer cancel() + for { + select { + case <-ctx.Done(): + return + case vars := <-c.Watch(): + if !observed { + vars, err = c.Observe(ctx, []string{"custom_fetch.vars.key1"}) + if err != nil { + setErr <- err + return + } + observed = true + } + if len(vars) > 0 { + node, err := vars[0].Replace("${custom_fetch.vars.key1}") + if err == nil { + // replace occurred so the fetch provider is now present + strNode, ok := node.(*transpiler.StrVal) + if !ok { + setErr <- fmt.Errorf("expected *transpiler.StrVal") + return + } + strVal, ok := strNode.Value().(string) + if !ok { + setErr <- fmt.Errorf("expected string") + return + } + if strVal != "vars.key1" { + setErr <- fmt.Errorf("expected replaced value error: %s != vars.key1", strVal) + return + } + // replacement worked + setErr <- nil + return + } + } + } + } + }() + + errCh := make(chan error) + go func() { + errCh <- c.Run(ctx) + }() + err = <-errCh + if errors.Is(err, context.Canceled) { + err = nil + } + require.NoError(t, err) + err = <-setErr + assert.NoError(t, err) +} + func TestProvidersDefaultDisabled(t *testing.T) { tests := []struct { name string @@ -425,3 +502,21 @@ func TestDefaultProvider(t *testing.T) { assert.Equal(t, "custom", c.DefaultProvider()) }) } + +type customFetchProvider struct{} + +func (c *customFetchProvider) Run(ctx context.Context, comm corecomp.ContextProviderComm) error { + <-ctx.Done() + return ctx.Err() +} + +func (c *customFetchProvider) Fetch(key string) (string, bool) { + tokens := strings.SplitN(key, ".", 2) + if len(tokens) > 0 && tokens[0] != "custom_fetch" { + return "", false + } + return tokens[1], true +} + +// validate it registers as a fetch provider +var _ corecomp.FetchContextProvider = (*customFetchProvider)(nil) diff --git a/internal/pkg/composable/dynamic.go b/internal/pkg/composable/dynamic.go index 74071dfa5dc..e54f27fdbcb 100644 --- a/internal/pkg/composable/dynamic.go +++ b/internal/pkg/composable/dynamic.go @@ -37,7 +37,7 @@ type DynamicProvider interface { type DynamicProviderBuilder func(log *logger.Logger, config *config.Config, managed bool) (DynamicProvider, error) // MustAddDynamicProvider adds a new DynamicProviderBuilder and panics if it AddDynamicProvider returns an error. -func (r *providerRegistry) MustAddDynamicProvider(name string, builder DynamicProviderBuilder) { +func (r *ProviderRegistry) MustAddDynamicProvider(name string, builder DynamicProviderBuilder) { err := r.AddDynamicProvider(name, builder) if err != nil { panic(err) @@ -47,7 +47,7 @@ func (r *providerRegistry) MustAddDynamicProvider(name string, builder DynamicPr // AddDynamicProvider adds a new DynamicProviderBuilder // //nolint:dupl,goimports,nolintlint // false positive -func (r *providerRegistry) AddDynamicProvider(providerName string, builder DynamicProviderBuilder) error { +func (r *ProviderRegistry) AddDynamicProvider(providerName string, builder DynamicProviderBuilder) error { r.lock.Lock() defer r.lock.Unlock() @@ -72,7 +72,7 @@ func (r *providerRegistry) AddDynamicProvider(providerName string, builder Dynam } // GetDynamicProvider returns the dynamic provider with the giving name, nil if it doesn't exist -func (r *providerRegistry) GetDynamicProvider(name string) (DynamicProviderBuilder, bool) { +func (r *ProviderRegistry) GetDynamicProvider(name string) (DynamicProviderBuilder, bool) { r.lock.RLock() defer r.lock.RUnlock() diff --git a/internal/pkg/composable/registry.go b/internal/pkg/composable/registry.go index 4d01b9884b5..7854ac8ded4 100644 --- a/internal/pkg/composable/registry.go +++ b/internal/pkg/composable/registry.go @@ -10,8 +10,8 @@ import ( "github.com/elastic/elastic-agent-libs/logp" ) -// providerRegistry is a registry of providers -type providerRegistry struct { +// ProviderRegistry is a registry of providers +type ProviderRegistry struct { contextProviders map[string]ContextProviderBuilder dynamicProviders map[string]DynamicProviderBuilder @@ -19,9 +19,14 @@ type providerRegistry struct { lock sync.RWMutex } -// Providers holds all known providers, they must be added to it to enable them for use -var Providers = &providerRegistry{ - contextProviders: make(map[string]ContextProviderBuilder), - dynamicProviders: make(map[string]DynamicProviderBuilder), - logger: logp.NewLogger("dynamic"), +// NewProviderRegistry creates a new provider registry. +func NewProviderRegistry() *ProviderRegistry { + return &ProviderRegistry{ + contextProviders: make(map[string]ContextProviderBuilder), + dynamicProviders: make(map[string]DynamicProviderBuilder), + logger: logp.NewLogger("composable"), + } } + +// Providers holds all known providers, they must be added to it to enable them for use +var Providers = NewProviderRegistry()