diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index c884301b0ee..a6b376b5a28 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -193,61 +193,12 @@ jobs: if: matrix.store == 'mysql' run: docker pull mlsmaycon/warmed-mysql:8 - - name: Start Postgres - if: matrix.store == 'postgres' - run: | - docker run -d \ - -e POSTGRES_USER=root \ - -e POSTGRES_PASSWORD=netbird \ - -e POSTGRES_DB=netbird \ - -p 5432:5432 \ - --name my-postgres \ - postgres:15-alpine - - - name: Wait for Postgres - if: matrix.store == 'postgres' - run: | - for i in {1..10}; do - if nc -z localhost 5432; then - break - fi - echo "Waiting for Postgres..." - sleep 3 - done - - - name: Start MySQL - if: matrix.store == 'mysql' - run: | - docker run -d \ - -p 3306:3306 \ - --name my-mysql \ - mlsmaycon/warmed-mysql:8 - - - name: Wait for MySQL - if: matrix.store == 'mysql' - run: | - for i in {1..10}; do - if nc -z localhost 3306; then - break - fi - echo "Waiting for MySQL..." - sleep 3 - done - - name: Test - run: | - if [ "${{ matrix.store }}" = "postgres" ]; then - export NETBIRD_STORE_ENGINE_POSTGRES_DSN="postgres://root:netbird@localhost:5432/netbird?sslmode=disable" - fi - - if [ "${{ matrix.store }}" = "mysql" ]; then - export NETBIRD_STORE_ENGINE_MYSQL_DSN="root:testing@tcp(localhost:3306)/testing" - fi - + run: | CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \ go test -tags=devcert -p 1 \ - -exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,NETBIRD_STORE_ENGINE_POSTGRES_DSN,NETBIRD_STORE_ENGINE_MYSQL_DSN" \ + -exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \ -timeout 10m $(go list ./... | grep /management) benchmark: diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 7b42cf55059..5f7b5801d58 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -24,7 +24,6 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/types" route2 "github.com/netbirdio/netbird/route" @@ -35,38 +34,6 @@ import ( nbroute "github.com/netbirdio/netbird/route" ) -func TestMain(m *testing.M) { - var cleanUpPostgres func() - var cleanUpMysql func() - - if dsn, ok := os.LookupEnv("NETBIRD_STORE_ENGINE_POSTGRES_DSN"); !ok || dsn == "" { - var err error - cleanUpPostgres, err = testutil.CreatePostgresTestContainer() - if err != nil { - os.Exit(1) - } - } - - if dsn, ok := os.LookupEnv("NETBIRD_STORE_ENGINE_MYSQL_DSN"); !ok || dsn == "" { - var err error - cleanUpMysql, err = testutil.CreateMysqlTestContainer() - if err != nil { - os.Exit(1) - } - } - - code := m.Run() - - if cleanUpPostgres != nil { - cleanUpPostgres() - } - if cleanUpMysql != nil { - cleanUpMysql() - } - - os.Exit(code) -} - var engines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine} func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { diff --git a/management/server/store/store.go b/management/server/store/store.go index 70c925a9968..f5049fb7468 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -338,28 +338,25 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) { if kind == PostgresStoreEngine { - var cleanUp func() - removeContainer := false if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" { var err error - cleanUp, err = testutil.CreatePostgresTestContainer() + _, err = testutil.CreatePostgresTestContainer() if err != nil { return nil, nil, err } - removeContainer = true } dsn, ok := os.LookupEnv(postgresDsnEnv) if !ok { - return nil, cleanUp, fmt.Errorf("%s is not set", postgresDsnEnv) + return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) if err != nil { - return nil, cleanUp, fmt.Errorf("failed to open postgres connection: %v", err) + return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err) } - newDsn, cleanup, err := createRandomDB(dsn, db, cleanUp, kind, removeContainer) + newDsn, cleanup, err := createRandomDB(dsn, db, kind) if err != nil { return nil, cleanup, err } @@ -374,14 +371,12 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store if kind == MysqlStoreEngine { var cleanUp func() - removeContainer := false if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" { var err error cleanUp, err = testutil.CreateMysqlTestContainer() if err != nil { return nil, nil, err } - removeContainer = true } dsn, ok := os.LookupEnv(mysqlDsnEnv) @@ -394,7 +389,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store return nil, cleanUp, fmt.Errorf("failed to open mysql connection: %v", err) } - newDsn, cleanup, err := createRandomDB(dsn, db, cleanUp, kind, removeContainer) + newDsn, cleanup, err := createRandomDB(dsn, db, kind) if err != nil { return nil, cleanup, err } @@ -414,16 +409,16 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store return store, closeConnection, nil } -func createRandomDB(dsn string, db *gorm.DB, cleanUp func(), engine Engine, removeContainer bool) (string, func(), error) { +func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) { dbName := fmt.Sprintf("test_db_%d", rand.Intn(1e6)) if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil { - return "", cleanUp, fmt.Errorf("failed to create database: %v", err) + return "", nil, fmt.Errorf("failed to create database: %v", err) } u, err := url.Parse(dsn) if err != nil { - return "", cleanUp, fmt.Errorf("failed to parse DSN: %v", err) + return "", nil, fmt.Errorf("failed to parse DSN: %v", err) } u.Path = dbName @@ -431,21 +426,53 @@ func createRandomDB(dsn string, db *gorm.DB, cleanUp func(), engine Engine, remo cleanup := func() { switch engine { case PostgresStoreEngine: - db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)) + err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error case MysqlStoreEngine: - db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)) + err = killMySQLConnections(dsn, dbName) + } + if err != nil { + log.Errorf("failed to drop database %s: %v", dbName, err) } sqlDB, _ := db.DB() _ = sqlDB.Close() - if cleanUp != nil && removeContainer { - cleanUp() - } } return u.String(), cleanup, nil } +func killMySQLConnections(dsn, targetDB string) error { + u, err := url.Parse(dsn) + if err != nil { + return fmt.Errorf("failed to parse DSN: %v", err) + } + + u.Path = "testing" + + ctrlDB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err != nil { + return fmt.Errorf("failed to open mysql connection: %v", err) + } + + var procs []struct { + ID int + } + query := "SELECT ID FROM INFORMATION_SCHEMA.PROCESSLIST WHERE DB = ?" + if err := ctrlDB.Raw(query, targetDB).Scan(&procs).Error; err != nil { + return fmt.Errorf("failed to get processes: %v", err) + } + + for _, p := range procs { + killStmt := fmt.Sprintf("KILL %d", p.ID) + if err := ctrlDB.Exec(killStmt).Error; err != nil { + return fmt.Errorf("failed to kill process %d: %v", p.ID, err) + } + } + + dropStmt := fmt.Sprintf("DROP DATABASE `%s`", targetDB) + return ctrlDB.Exec(dropStmt).Error +} + func loadSQL(db *gorm.DB, filepath string) error { sqlContent, err := os.ReadFile(filepath) if err != nil {