diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 6fb9f6a29a9..150953f51c7 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -129,7 +129,7 @@ func TestSaveDNSSettings(t *testing.T) { account, err := initTestDNSAccount(t, am) if err != nil { - t.Error("failed to init testing account") + t.Errorf("failed to init testing account: %v", err) } err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 5f7b5801d58..f8e9bf0c0df 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -34,11 +34,9 @@ import ( nbroute "github.com/netbirdio/netbird/route" ) -var engines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine} - func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { t.Helper() - for _, engine := range engines { + for _, engine := range supportedEngines { if os.Getenv("NETBIRD_STORE_ENGINE") != "" && os.Getenv("NETBIRD_STORE_ENGINE") != string(engine) { continue } diff --git a/management/server/store/store.go b/management/server/store/store.go index 427921db3bd..c3959f88f98 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -12,6 +12,7 @@ import ( "path" "path/filepath" "runtime" + "slices" "strings" "time" @@ -181,6 +182,8 @@ const ( mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN" ) +var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine} + func getStoreEngineFromEnv() Engine { // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file. kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") @@ -189,7 +192,7 @@ func getStoreEngineFromEnv() Engine { } value := Engine(strings.ToLower(kind)) - if value == SqliteStoreEngine || value == PostgresStoreEngine || value == MysqlStoreEngine { + if slices.Contains(supportedEngines, value) { return value } @@ -337,6 +340,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( } func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) { + var cleanup func() if kind == PostgresStoreEngine { if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" { var err error @@ -356,24 +360,21 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err) } - newDsn, cleanup, err := createRandomDB(dsn, db, kind) + dsn, cleanup, err = createRandomDB(dsn, db, kind) if err != nil { return nil, cleanup, err } - store, err := NewPostgresqlStoreFromSqlStore(ctx, store, newDsn, nil) + store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) if err != nil { return nil, cleanup, err } - - return store, cleanup, nil } if kind == MysqlStoreEngine { - var cleanUp func() if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" { var err error - cleanUp, err = testutil.CreateMysqlTestContainer() + _, err = testutil.CreateMysqlTestContainer() if err != nil { return nil, nil, err } @@ -386,23 +387,23 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) if err != nil { - return nil, cleanUp, fmt.Errorf("failed to open mysql connection: %v", err) + return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err) } - newDsn, cleanup, err := createRandomDB(dsn, db, kind) + dsn, cleanup, err = createRandomDB(dsn, db, kind) if err != nil { return nil, cleanup, err } - store, err := NewMysqlStoreFromSqlStore(ctx, store, newDsn, nil) + store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil) if err != nil { return nil, nil, err } - return store, cleanup, nil } closeConnection := func() { + cleanup() store.Close(ctx) } @@ -410,7 +411,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store } func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) { - dbName := fmt.Sprintf("test_db_%d", rand.Intn(1e6)) + dbName := fmt.Sprintf("test_db_%d", rand.Intn(1e9)) if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil { return "", nil, fmt.Errorf("failed to create database: %v", err) @@ -433,6 +434,7 @@ func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), err } if err != nil { log.Errorf("failed to drop database %s: %v", dbName, err) + panic(err) } sqlDB, _ := db.DB() _ = sqlDB.Close()