Skip to content

Commit

Permalink
Fix panic on fetch provider initialization (#6958)
Browse files Browse the repository at this point in the history
  • Loading branch information
blakerouse authored Feb 28, 2025
1 parent 4a7add4 commit 337a42a
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 45 deletions.
6 changes: 3 additions & 3 deletions internal/pkg/composable/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
type ContextProviderBuilder func(log *logger.Logger, config *config.Config, managed bool) (corecomp.ContextProvider, error)

// MustAddContextProvider adds a new ContextProviderBuilder and panics if it AddContextProvider returns an error.
func (r *providerRegistry) MustAddContextProvider(name string, builder ContextProviderBuilder) {
func (r *ProviderRegistry) MustAddContextProvider(name string, builder ContextProviderBuilder) {
err := r.AddContextProvider(name, builder)
if err != nil {
panic(err)
Expand All @@ -27,7 +27,7 @@ func (r *providerRegistry) MustAddContextProvider(name string, builder ContextPr
// AddContextProvider adds a new ContextProviderBuilder
//
//nolint:dupl,goimports,nolintlint // false positive
func (r *providerRegistry) AddContextProvider(name string, builder ContextProviderBuilder) error {
func (r *ProviderRegistry) AddContextProvider(name string, builder ContextProviderBuilder) error {
r.lock.Lock()
defer r.lock.Unlock()

Expand Down Expand Up @@ -55,7 +55,7 @@ func (r *providerRegistry) AddContextProvider(name string, builder ContextProvid
}

// GetContextProvider returns the context provider with the giving name, nil if it doesn't exist
func (r *providerRegistry) GetContextProvider(name string) (ContextProviderBuilder, bool) {
func (r *ProviderRegistry) GetContextProvider(name string) (ContextProviderBuilder, bool) {
r.lock.RLock()
defer r.lock.RUnlock()

Expand Down
58 changes: 26 additions & 32 deletions internal/pkg/composable/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,13 @@ type controller struct {
dynamicProviderStates map[string]*dynamicProviderState
}

// New creates a new controller.
// New creates a new controller with the global set of providers.
func New(log *logger.Logger, c *config.Config, managed bool) (Controller, error) {
return NewWithProviders(log, c, managed, Providers)
}

// NewWithProviders creates a new controller with the given set of providers.
func NewWithProviders(log *logger.Logger, c *config.Config, managed bool, providers *ProviderRegistry) (Controller, error) {
l := log.Named("composable")

var providersCfg Config
Expand Down Expand Up @@ -110,7 +115,7 @@ func New(log *logger.Logger, c *config.Config, managed bool) (Controller, error)

// build all the context providers
contextProviders := map[string]contextProvider{}
for name, builder := range Providers.contextProviders {
for name, builder := range providers.contextProviders {
pCfg, ok := providersCfg.Providers[name]
if (ok && !pCfg.Enabled()) || (!ok && !providersInitialDefault) {
// explicitly disabled; skipping
Expand All @@ -124,7 +129,7 @@ func New(log *logger.Logger, c *config.Config, managed bool) (Controller, error)

// build all the dynamic providers
dynamicProviders := map[string]dynamicProvider{}
for name, builder := range Providers.dynamicProviders {
for name, builder := range providers.dynamicProviders {
pCfg, ok := providersCfg.Providers[name]
if (ok && !pCfg.Enabled()) || (!ok && !providersInitialDefault) {
// explicitly disabled; skipping
Expand Down Expand Up @@ -187,45 +192,33 @@ func (c *controller) Run(ctx context.Context) error {
wg.Wait()
}()

// synchronize the fetch providers through a channel
var fetchProvidersLock sync.RWMutex
var fetchProviders mapstr.M
fetchCh := make(chan fetchProvider)
go func() {
for {
select {
case <-localCtx.Done():
return
case msg := <-fetchCh:
fetchProvidersLock.Lock()
if msg.fetchProvider == nil {
_ = fetchProviders.Delete(msg.name)
} else {
_, _ = fetchProviders.Put(msg.name, msg.fetchProvider)
}
fetchProvidersLock.Unlock()
}
}
}()

// send initial vars state
fetchProvidersLock.RLock()
// send initial vars state (empty fetch providers initially)
fetchProviders := mapstr.M{}
err := c.sendVars(ctx, nil, fetchProviders)
if err != nil {
fetchProvidersLock.RUnlock()
// only error is context cancel, no need to add error message context
return err
}
fetchProvidersLock.RUnlock()

// performs debounce of notifies; accumulates them into 100 millisecond chunks
var observedResult chan []*transpiler.Vars
fetchCh := make(chan fetchProvider)
for {
DEBOUNCE:
for {
select {
case <-ctx.Done():
return ctx.Err()
case msg := <-fetchCh:
if msg.fetchProvider == nil {
_ = fetchProviders.Delete(msg.name)
} else {
_, _ = fetchProviders.Put(msg.name, msg.fetchProvider)
}
t.Reset(100 * time.Millisecond)
c.logger.Debugf("Fetch providers state changed for composable inputs; debounce started")
drainChan(stateChangedChan) // state change trigger (no need for signal to be handled)
break DEBOUNCE
case observed := <-c.observedCh:
// observedResult holds the channel to send the latest observed results on
// if nothing is changed then nil will be sent over the channel if the set of running
Expand All @@ -235,7 +228,7 @@ func (c *controller) Run(ctx context.Context) error {
if changed {
t.Reset(100 * time.Millisecond)
c.logger.Debugf("Observed state changed for composable inputs; debounce started")
drainChan(stateChangedChan)
drainChan(stateChangedChan) // state change trigger (no need for signal to be handled)
break DEBOUNCE
} else {
// nothing changed send nil to alert the caller
Expand All @@ -261,15 +254,12 @@ func (c *controller) Run(ctx context.Context) error {
}

// send the vars to the watcher or the observer caller
fetchProvidersLock.RLock()
err := c.sendVars(ctx, observedResult, fetchProviders)
observedResult = nil
if err != nil {
fetchProvidersLock.RUnlock()
// only error is context cancel, no need to add error message context
return err
}
fetchProvidersLock.RUnlock()
}
}

Expand Down Expand Up @@ -550,6 +540,10 @@ func (c *controller) startDynamicProvider(ctx context.Context, wg *sync.WaitGrou
}

func (c *controller) generateVars(fetchContextProviders mapstr.M, defaultProvider string) []*transpiler.Vars {
// copy fetch providers map so they cannot change in the context
// of the currently processed variables
fetchContextProviders = fetchContextProviders.Clone()

// build the vars list of mappings
vars := make([]*transpiler.Vars, 1)
mapping, _ := transpiler.NewAST(map[string]any{})
Expand Down
95 changes: 95 additions & 0 deletions internal/pkg/composable/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"errors"
"fmt"
"strings"
"testing"
"time"

Expand All @@ -17,6 +18,7 @@ import (
"github.com/elastic/elastic-agent/internal/pkg/agent/transpiler"
"github.com/elastic/elastic-agent/internal/pkg/composable"
"github.com/elastic/elastic-agent/internal/pkg/config"
corecomp "github.com/elastic/elastic-agent/internal/pkg/core/composable"
"github.com/elastic/elastic-agent/pkg/core/logger"

_ "github.com/elastic/elastic-agent/internal/pkg/composable/providers/env"
Expand Down Expand Up @@ -142,6 +144,81 @@ func TestController(t *testing.T) {
assert.Len(t, vars3map, 0) // should be empty after empty Observe
}

func TestControllerWithFetchProvider(t *testing.T) {
providers := composable.NewProviderRegistry()
providers.MustAddContextProvider("custom_fetch", func(_ *logger.Logger, _ *config.Config, _ bool) (corecomp.ContextProvider, error) {
// add a delay to ensure that even if it takes time to start the provider that it still gets placed
// as a fetch provider
<-time.After(1 * time.Second)
return &customFetchProvider{}, nil
})

cfg := config.New()
log, err := logger.New("", false)
require.NoError(t, err)
c, err := composable.NewWithProviders(log, cfg, false, providers)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

observed := false
setErr := make(chan error, 1)
go func() {
defer cancel()
for {
select {
case <-ctx.Done():
return
case vars := <-c.Watch():
if !observed {
vars, err = c.Observe(ctx, []string{"custom_fetch.vars.key1"})
if err != nil {
setErr <- err
return
}
observed = true
}
if len(vars) > 0 {
node, err := vars[0].Replace("${custom_fetch.vars.key1}")
if err == nil {
// replace occurred so the fetch provider is now present
strNode, ok := node.(*transpiler.StrVal)
if !ok {
setErr <- fmt.Errorf("expected *transpiler.StrVal")
return
}
strVal, ok := strNode.Value().(string)
if !ok {
setErr <- fmt.Errorf("expected string")
return
}
if strVal != "vars.key1" {
setErr <- fmt.Errorf("expected replaced value error: %s != vars.key1", strVal)
return
}
// replacement worked
setErr <- nil
return
}
}
}
}
}()

errCh := make(chan error)
go func() {
errCh <- c.Run(ctx)
}()
err = <-errCh
if errors.Is(err, context.Canceled) {
err = nil
}
require.NoError(t, err)
err = <-setErr
assert.NoError(t, err)
}

func TestProvidersDefaultDisabled(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -425,3 +502,21 @@ func TestDefaultProvider(t *testing.T) {
assert.Equal(t, "custom", c.DefaultProvider())
})
}

type customFetchProvider struct{}

func (c *customFetchProvider) Run(ctx context.Context, comm corecomp.ContextProviderComm) error {
<-ctx.Done()
return ctx.Err()
}

func (c *customFetchProvider) Fetch(key string) (string, bool) {
tokens := strings.SplitN(key, ".", 2)
if len(tokens) > 0 && tokens[0] != "custom_fetch" {
return "", false
}
return tokens[1], true
}

// validate it registers as a fetch provider
var _ corecomp.FetchContextProvider = (*customFetchProvider)(nil)
6 changes: 3 additions & 3 deletions internal/pkg/composable/dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type DynamicProvider interface {
type DynamicProviderBuilder func(log *logger.Logger, config *config.Config, managed bool) (DynamicProvider, error)

// MustAddDynamicProvider adds a new DynamicProviderBuilder and panics if it AddDynamicProvider returns an error.
func (r *providerRegistry) MustAddDynamicProvider(name string, builder DynamicProviderBuilder) {
func (r *ProviderRegistry) MustAddDynamicProvider(name string, builder DynamicProviderBuilder) {
err := r.AddDynamicProvider(name, builder)
if err != nil {
panic(err)
Expand All @@ -47,7 +47,7 @@ func (r *providerRegistry) MustAddDynamicProvider(name string, builder DynamicPr
// AddDynamicProvider adds a new DynamicProviderBuilder
//
//nolint:dupl,goimports,nolintlint // false positive
func (r *providerRegistry) AddDynamicProvider(providerName string, builder DynamicProviderBuilder) error {
func (r *ProviderRegistry) AddDynamicProvider(providerName string, builder DynamicProviderBuilder) error {
r.lock.Lock()
defer r.lock.Unlock()

Expand All @@ -72,7 +72,7 @@ func (r *providerRegistry) AddDynamicProvider(providerName string, builder Dynam
}

// GetDynamicProvider returns the dynamic provider with the giving name, nil if it doesn't exist
func (r *providerRegistry) GetDynamicProvider(name string) (DynamicProviderBuilder, bool) {
func (r *ProviderRegistry) GetDynamicProvider(name string) (DynamicProviderBuilder, bool) {
r.lock.RLock()
defer r.lock.RUnlock()

Expand Down
19 changes: 12 additions & 7 deletions internal/pkg/composable/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,23 @@ import (
"github.com/elastic/elastic-agent-libs/logp"
)

// providerRegistry is a registry of providers
type providerRegistry struct {
// ProviderRegistry is a registry of providers
type ProviderRegistry struct {
contextProviders map[string]ContextProviderBuilder
dynamicProviders map[string]DynamicProviderBuilder

logger *logp.Logger
lock sync.RWMutex
}

// Providers holds all known providers, they must be added to it to enable them for use
var Providers = &providerRegistry{
contextProviders: make(map[string]ContextProviderBuilder),
dynamicProviders: make(map[string]DynamicProviderBuilder),
logger: logp.NewLogger("dynamic"),
// NewProviderRegistry creates a new provider registry.
func NewProviderRegistry() *ProviderRegistry {
return &ProviderRegistry{
contextProviders: make(map[string]ContextProviderBuilder),
dynamicProviders: make(map[string]DynamicProviderBuilder),
logger: logp.NewLogger("composable"),
}
}

// Providers holds all known providers, they must be added to it to enable them for use
var Providers = NewProviderRegistry()

0 comments on commit 337a42a

Please sign in to comment.