Skip to content

Commit

Permalink
Add missing null checks to AST methods (#7009)
Browse files Browse the repository at this point in the history
AST collection nodes can contain nil values, and we currently don't
check for them everywhere we should. Fix this. A more comprehensive
solution would involve parsing nil into its own Node type, so we can
avoid the checks.
  • Loading branch information
swiatekm authored Feb 25, 2025
1 parent 394fe1d commit bb2191a
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 4 deletions.
32 changes: 32 additions & 0 deletions changelog/fragments/1740485771-ast-null-checks.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Kind can be one of:
# - breaking-change: a change to previously-documented behavior
# - deprecation: functionality that is being removed in a later release
# - bug-fix: fixes a problem in a previous version
# - enhancement: extends functionality but does not break or fix existing behavior
# - feature: new functionality
# - known-issue: problems that we are aware of in a given version
# - security: impacts on the security of a product or a user’s deployment.
# - upgrade: important information for someone upgrading from a prior version
# - other: does not fit into any of the other categories
kind: bug-fix

# Change summary; a 80ish characters long description of the change.
summary: Add missing null checks to AST methods

# Long description; in case the summary is not enough to describe the change
# this field accommodate a description without length limits.
# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment.
#description:

# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc.
component: elastic-agent

# PR URL; optional; the PR number that added the changeset.
# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added.
# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number.
# Please provide it if you are adding a fragment for a different PR.
#pr: https://github.com/owner/repo/1234

# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of).
# If not present is automatically filled by the tooling with the issue linked to the PR number.
issue: https://github.com/elastic/elastic-agent/issues/6999
47 changes: 43 additions & 4 deletions internal/pkg/agent/transpiler/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ func NewDictWithProcessors(nodes []Node, processors Processors) *Dict {
// Find takes a string which is a key and try to find the elements in the associated K/V.
func (d *Dict) Find(key string) (Node, bool) {
for _, i := range d.value {
if i == nil {
continue
}
if i.(*Key).name == key {
return i, true
}
Expand All @@ -119,9 +122,12 @@ func (d *Dict) Insert(node Node) {

func (d *Dict) String() string {
var sb strings.Builder
for i := 0; i < len(d.value); i++ {
for i, node := range d.value {
if node == nil {
continue
}
sb.WriteString("{")
sb.WriteString(d.value[i].String())
sb.WriteString(node.String())
sb.WriteString("}")
if i < len(d.value)-1 {
sb.WriteString(",")
Expand Down Expand Up @@ -166,6 +172,9 @@ func (d *Dict) ShallowClone() Node {
func (d *Dict) Hash() []byte {
h := sha256.New()
for _, v := range d.value {
if v == nil {
continue
}
h.Write(v.Hash())
}
return h.Sum(nil)
Expand All @@ -174,6 +183,9 @@ func (d *Dict) Hash() []byte {
// Hash64With recursively computes the given hash for the Node and its children
func (d *Dict) Hash64With(h *xxhash.Digest) error {
for _, v := range d.value {
if v == nil {
continue
}
if err := v.Hash64With(h); err != nil {
return err
}
Expand All @@ -184,6 +196,9 @@ func (d *Dict) Hash64With(h *xxhash.Digest) error {
// Vars returns a list of all variables referenced in the dictionary.
func (d *Dict) Vars(vars []string, defaultProvider string) []string {
for _, v := range d.value {
if v == nil {
continue
}
k := v.(*Key)
vars = k.Vars(vars, defaultProvider)
}
Expand All @@ -194,6 +209,9 @@ func (d *Dict) Vars(vars []string, defaultProvider string) []string {
func (d *Dict) Apply(vars *Vars) (Node, error) {
nodes := make([]Node, 0, len(d.value))
for _, v := range d.value {
if v == nil {
continue
}
k := v.(*Key)
n, err := k.Apply(vars)
if err != nil {
Expand Down Expand Up @@ -222,6 +240,9 @@ func (d *Dict) Processors() Processors {
return d.processors
}
for _, v := range d.value {
if v == nil {
continue
}
if p := v.Processors(); p != nil {
return p
}
Expand Down Expand Up @@ -387,8 +408,11 @@ func NewListWithProcessors(nodes []Node, processors Processors) *List {
func (l *List) String() string {
var sb strings.Builder
sb.WriteString("[")
for i := 0; i < len(l.value); i++ {
sb.WriteString(l.value[i].String())
for i, v := range l.value {
if v == nil {
continue
}
sb.WriteString(v.String())
if i < len(l.value)-1 {
sb.WriteString(",")
}
Expand All @@ -401,6 +425,9 @@ func (l *List) String() string {
func (l *List) Hash() []byte {
h := sha256.New()
for _, v := range l.value {
if v == nil {
continue
}
h.Write(v.Hash())
}

Expand All @@ -410,6 +437,9 @@ func (l *List) Hash() []byte {
// Hash64With recursively computes the given hash for the Node and its children
func (l *List) Hash64With(h *xxhash.Digest) error {
for _, v := range l.value {
if v == nil {
continue
}
if err := v.Hash64With(h); err != nil {
return err
}
Expand Down Expand Up @@ -465,6 +495,9 @@ func (l *List) ShallowClone() Node {
// Vars returns a list of all variables referenced in the list.
func (l *List) Vars(vars []string, defaultProvider string) []string {
for _, v := range l.value {
if v == nil {
continue
}
vars = v.Vars(vars, defaultProvider)
}
return vars
Expand All @@ -474,6 +507,9 @@ func (l *List) Vars(vars []string, defaultProvider string) []string {
func (l *List) Apply(vars *Vars) (Node, error) {
nodes := make([]Node, 0, len(l.value))
for _, v := range l.value {
if v == nil {
continue
}
n, err := v.Apply(vars)
if err != nil {
return nil, err
Expand All @@ -492,6 +528,9 @@ func (l *List) Processors() Processors {
return l.processors
}
for _, v := range l.value {
if v == nil {
continue
}
if p := v.Processors(); p != nil {
return p
}
Expand Down
45 changes: 45 additions & 0 deletions internal/pkg/agent/transpiler/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"reflect"
"testing"

"github.com/cespare/xxhash/v2"

"github.com/elastic/elastic-agent-libs/mapstr"

"github.com/elastic/elastic-agent/internal/pkg/eql"
Expand Down Expand Up @@ -1206,6 +1208,49 @@ func TestCondition(t *testing.T) {
assert.Nil(t, input2.condition)
}

// check that all the methods handle nil values correctly
func TestNullValues(t *testing.T) {
cfgMap := map[string]any{
"inputs": map[string]any{
"dict": map[string]any{
"key": nil,
},
"list": []any{nil},
},
}
ast, err := NewAST(cfgMap)
require.NoError(t, err)
inputs, ok := Lookup(ast, "inputs")
require.True(t, ok)

assert.NotEmpty(t, inputs.String())

node, ok := inputs.Find("dict")
assert.True(t, ok)
assert.NotNil(t, node)

assert.NotNil(t, inputs.Value())

assert.NotNil(t, inputs.Clone())

assert.NotNil(t, inputs.ShallowClone())

assert.NotEmpty(t, inputs.Hash())

h := xxhash.New()
err = inputs.Hash64With(h)
assert.NoError(t, err)
assert.NotEmpty(t, h.Sum64())

assert.Empty(t, inputs.Vars([]string{}, "default"))

newNode, err := inputs.Apply(nil)
assert.NoError(t, err)
assert.NotNil(t, newNode)

assert.Empty(t, inputs.Processors())
}

func mustMakeVars(mapping map[string]interface{}) *Vars {
v, err := NewVars("", mapping, nil, "")
if err != nil {
Expand Down

0 comments on commit bb2191a

Please sign in to comment.