Skip to content

Commit

Permalink
extract supported engines
Browse files Browse the repository at this point in the history
  • Loading branch information
pascal-fischer committed Jan 20, 2025
1 parent b00d633 commit 6328112
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion management/server/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestSaveDNSSettings(t *testing.T) {

account, err := initTestDNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
t.Errorf("failed to init testing account: %v", err)
}

err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
Expand Down
4 changes: 1 addition & 3 deletions management/server/store/sql_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@ import (
nbroute "github.com/netbirdio/netbird/route"
)

var engines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine}

func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
t.Helper()
for _, engine := range engines {
for _, engine := range supportedEngines {
if os.Getenv("NETBIRD_STORE_ENGINE") != "" && os.Getenv("NETBIRD_STORE_ENGINE") != string(engine) {
continue
}
Expand Down
26 changes: 14 additions & 12 deletions management/server/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"path"
"path/filepath"
"runtime"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -181,6 +182,8 @@ const (
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)

var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine}

func getStoreEngineFromEnv() Engine {
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
Expand All @@ -189,7 +192,7 @@ func getStoreEngineFromEnv() Engine {
}

value := Engine(strings.ToLower(kind))
if value == SqliteStoreEngine || value == PostgresStoreEngine || value == MysqlStoreEngine {
if slices.Contains(supportedEngines, value) {
return value
}

Expand Down Expand Up @@ -337,6 +340,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
}

func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
var cleanup func()
if kind == PostgresStoreEngine {
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
var err error
Expand All @@ -356,24 +360,21 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store
return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
}

newDsn, cleanup, err := createRandomDB(dsn, db, kind)
dsn, cleanup, err = createRandomDB(dsn, db, kind)
if err != nil {
return nil, cleanup, err
}

store, err := NewPostgresqlStoreFromSqlStore(ctx, store, newDsn, nil)
store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
if err != nil {
return nil, cleanup, err
}

return store, cleanup, nil
}

if kind == MysqlStoreEngine {
var cleanUp func()
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
var err error
cleanUp, err = testutil.CreateMysqlTestContainer()
_, err = testutil.CreateMysqlTestContainer()
if err != nil {
return nil, nil, err
}
Expand All @@ -386,31 +387,31 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store

db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
if err != nil {
return nil, cleanUp, fmt.Errorf("failed to open mysql connection: %v", err)
return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err)
}

newDsn, cleanup, err := createRandomDB(dsn, db, kind)
dsn, cleanup, err = createRandomDB(dsn, db, kind)
if err != nil {
return nil, cleanup, err
}

store, err := NewMysqlStoreFromSqlStore(ctx, store, newDsn, nil)
store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil)
if err != nil {
return nil, nil, err
}

return store, cleanup, nil
}

closeConnection := func() {
cleanup()
store.Close(ctx)
}

return store, closeConnection, nil
}

func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) {
dbName := fmt.Sprintf("test_db_%d", rand.Intn(1e6))
dbName := fmt.Sprintf("test_db_%d", rand.Intn(1e9))

if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
return "", nil, fmt.Errorf("failed to create database: %v", err)
Expand All @@ -433,6 +434,7 @@ func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), err
}
if err != nil {
log.Errorf("failed to drop database %s: %v", dbName, err)
panic(err)
}
sqlDB, _ := db.DB()
_ = sqlDB.Close()
Expand Down

0 comments on commit 6328112

Please sign in to comment.