Skip to content

Commit

Permalink
kwild: re-add extension config flags
Browse files Browse the repository at this point in the history
  • Loading branch information
brennanjl authored Feb 12, 2025
1 parent ba185ec commit 9c0f439
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 1 deletion.
62 changes: 61 additions & 1 deletion app/node/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package node

import (
"fmt"
"strings"

"github.com/spf13/cobra"

Expand All @@ -27,14 +28,22 @@ func StartCmd() *cobra.Command {
CompletionOptions: cobra.CompletionOptions{
DisableDefaultCmd: true,
},
Args: cobra.NoArgs,
Version: version.KwilVersion,
Example: custom.BinaryConfig.NodeCmd + " start -r .testnet",
RunE: func(cmd *cobra.Command, args []string) error {
rootDir := conf.RootDir()

extConfs, err := parseExtensionFlags(args)
if err != nil {
return err
}

cfg := conf.ActiveConfig()

// we don't need to worry about order of priority with applying the extension
// flag configs because flags are always highest priority
cfg.Extensions = extConfs

bind.Debugf("effective node config (toml):\n%s", bind.LazyPrinter(func() string {
rawToml, err := cfg.ToTOML()
if err != nil {
Expand Down Expand Up @@ -72,3 +81,54 @@ func StartCmd() *cobra.Command {

return cmd
}

// parseExtensionFlags parses the extension flags from the command line and
// returns a map of extension names to their configured values
func parseExtensionFlags(args []string) (map[string]map[string]string, error) {
exts := make(map[string]map[string]string)
for i := 0; i < len(args); i++ {
if !strings.HasPrefix(args[i], "--extension.") {
return nil, fmt.Errorf("expected extension flag, got %q", args[i])
}
// split the flag into the extension name and the flag name
// we intentionally do not use SplitN because we want to verify
// there are exactly 3 parts.
parts := strings.Split(args[i], ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid extension flag %q", args[i])
}

extName := parts[1]

// get the extension map for the extension name.
// if it doesn't exist, create it.
ext, ok := exts[extName]
if !ok {
ext = make(map[string]string)
exts[extName] = ext
}

// we now need to get the flag value. Flags can be passed
// as either "--extension.extname.flagname value" or
// "--extension.extname.flagname=value"
if strings.Contains(parts[2], "=") {
// flag value is in the same argument
val := strings.SplitN(parts[2], "=", 2)
ext[val[0]] = val[1]
} else {
// flag value is in the next argument
if i+1 >= len(args) {
return nil, fmt.Errorf("missing value for extension flag %q", args[i])
}

if strings.HasPrefix(args[i+1], "--") {
return nil, fmt.Errorf("missing value for extension flag %q", args[i])
}

ext[parts[2]] = args[i+1]
i++
}
}

return exts, nil
}
72 changes: 72 additions & 0 deletions app/node/start_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package node

import (
"testing"

"github.com/stretchr/testify/require"
)

func Test_ExtensionFlags(t *testing.T) {
type testcase struct {
name string
flagset []string
want map[string]map[string]string
wantErr bool
}

tests := []testcase{
{
name: "empty flagset",
flagset: []string{},
want: map[string]map[string]string{},
},
{
name: "single flag",
flagset: []string{"--extension.extname.flagname", "value"},
want: map[string]map[string]string{
"extname": {
"flagname": "value",
},
},
},
{
name: "multiple flags",
flagset: []string{"--extension.extname.flagname", "value", "--extension.extname2.flagname2=value2"},
want: map[string]map[string]string{
"extname": {
"flagname": "value",
},
"extname2": {
"flagname2": "value2",
},
},
},
{
name: "missing value",
flagset: []string{
"--extension.extname.flagname",
},
wantErr: true,
},
{
name: "pass flag as a value errors",
flagset: []string{
"--extension.extname.flagname", "--extension.extname2.flagname2=value2",
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseExtensionFlags(tt.flagset)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)

require.EqualValues(t, tt.want, got)
})
}
}

0 comments on commit 9c0f439

Please sign in to comment.