From 0a4aa10d32beaf89aece797845074d915225749c Mon Sep 17 00:00:00 2001 From: Subrose Date: Mon, 11 Dec 2023 17:28:35 +0000 Subject: [PATCH] sql working now --- vault/go.mod | 3 --- vault/sql.go | 67 ++++++++++++++++++++++------------------------------ 2 files changed, 28 insertions(+), 42 deletions(-) diff --git a/vault/go.mod b/vault/go.mod index fc8ae6b..b4d9035 100644 --- a/vault/go.mod +++ b/vault/go.mod @@ -13,9 +13,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/doug-martin/goqu v5.0.0+incompatible // indirect - github.com/doug-martin/goqu/v9 v9.19.0 // indirect - github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/huandu/go-sqlbuilder v1.24.0 // indirect diff --git a/vault/sql.go b/vault/sql.go index 5e4a872..fc089f3 100644 --- a/vault/sql.go +++ b/vault/sql.go @@ -10,8 +10,6 @@ import ( "regexp" "time" - "github.com/doug-martin/goqu/v9" - _ "github.com/doug-martin/goqu/v9/dialect/postgres" "github.com/lib/pq" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -19,27 +17,18 @@ import ( ) type SqlStore struct { - db *goqu.Database - gdb *gorm.DB + db *gorm.DB } func NewSqlStore(dsn string) (*SqlStore, error) { - pgDb, err := sql.Open("postgres", dsn) - if err != nil { - return nil, err - } - dialect := goqu.Dialect("postgres") - db := dialect.DB(pgDb) - time.Local = time.UTC - - gdb, err := gorm.Open(postgres.Open(dsn), &gorm.Config{TranslateError: true, Logger: logger.Default.LogMode(logger.Silent)}) - // gdb = gdb.Debug() + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{TranslateError: true, Logger: logger.Default.LogMode(logger.Silent)}) + // db = db.Debug() if err != nil { return nil, err } - store := &SqlStore{db, gdb} + store := &SqlStore{db} err = store.CreateSchemas() if err != nil { @@ -138,7 +127,7 @@ func (dbCollectionMetadata) TableName() string { func (st *SqlStore) CreateSchemas() error { // Use GORM's automigrate to create tables - err := st.gdb.AutoMigrate(&dbPrincipal{}, &dbPolicy{}, &dbPrincipalPolicy{}, &dbToken{}, &dbCollectionMetadata{}) + err := st.db.AutoMigrate(&dbPrincipal{}, &dbPolicy{}, &dbPrincipalPolicy{}, &dbToken{}, &dbCollectionMetadata{}) if err != nil { return err } @@ -153,7 +142,7 @@ func validateInput(input string) bool { func (st *SqlStore) CreateCollection(ctx context.Context, c *Collection) error { // Start a new transaction - tx := st.gdb.Begin() + tx := st.db.Begin() if tx.Error != nil { return tx.Error } @@ -220,7 +209,7 @@ func getCollectionFields(ctx context.Context, db *gorm.DB, collectionName string } func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, error) { - fields, err := getCollectionFields(ctx, st.gdb, name) + fields, err := getCollectionFields(ctx, st.db, name) if err != nil { return nil, err } @@ -229,7 +218,7 @@ func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, func (st SqlStore) GetCollections(ctx context.Context) ([]string, error) { var collectionMetadatas []dbCollectionMetadata - result := st.gdb.Find(&collectionMetadatas) + result := st.db.Find(&collectionMetadatas) if result.Error != nil { return nil, result.Error } @@ -244,7 +233,7 @@ func (st SqlStore) DeleteCollection(ctx context.Context, name string) error { if !validateInput(name) { return &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", name)} } - tx := st.gdb.Begin() + tx := st.db.Begin() if tx.Error != nil { return tx.Error } @@ -282,7 +271,7 @@ func (st SqlStore) CreateRecord(ctx context.Context, collectionName string, reco return "", &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", collectionName)} } - fields, err := getCollectionFields(ctx, st.gdb, collectionName) + fields, err := getCollectionFields(ctx, st.db, collectionName) if err != nil { return "", err } @@ -306,7 +295,7 @@ func (st SqlStore) CreateRecord(ctx context.Context, collectionName string, reco } // Use gorm's Create method with the map - result := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Create(&newRecord) + result := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Create(&newRecord) if result.Error != nil { return "", result.Error } @@ -320,7 +309,7 @@ func (st SqlStore) GetRecords(ctx context.Context, collectionName string) ([]str } var recordIds []string - result := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Pluck("id", &recordIds) + result := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Pluck("id", &recordIds) if result.Error != nil { return nil, result.Error } @@ -334,7 +323,7 @@ func (st SqlStore) GetRecord(ctx context.Context, collectionName string, recordI } record := make(Record) - rows, err := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Select("*").Rows() + rows, err := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Select("*").Rows() if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, &NotFoundError{"record", recordID} @@ -379,7 +368,7 @@ func (st SqlStore) UpdateRecord(ctx context.Context, collectionName string, reco return &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", collectionName)} } - fields, err := getCollectionFields(ctx, st.gdb, collectionName) + fields, err := getCollectionFields(ctx, st.db, collectionName) if err != nil { return err } @@ -402,7 +391,7 @@ func (st SqlStore) UpdateRecord(ctx context.Context, collectionName string, reco } // Use gorm's Update method with the map - result := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Updates(newRecord) + result := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Updates(newRecord) if result.Error != nil { return result.Error } @@ -418,7 +407,7 @@ func (st SqlStore) DeleteRecord(ctx context.Context, collectionName string, reco return &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", collectionName)} } - result := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Delete(&Record{}) + result := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Delete(&Record{}) if result.Error != nil { return result.Error } @@ -431,7 +420,7 @@ func (st SqlStore) DeleteRecord(ctx context.Context, collectionName string, reco func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principal, error) { var dbPrincipal dbPrincipal - err := st.gdb.Preload("Policies").Where("username = ?", username).First(&dbPrincipal).Error + err := st.db.Preload("Policies").Where("username = ?", username).First(&dbPrincipal).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, &NotFoundError{"principal", username} @@ -457,7 +446,7 @@ func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principa } func (st SqlStore) CreatePrincipal(ctx context.Context, principal *Principal) error { - tx := st.gdb.Begin() + tx := st.db.Begin() defer func() { if r := recover(); r != nil { @@ -505,7 +494,7 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal *Principal) er } func (st SqlStore) DeletePrincipal(ctx context.Context, id string) error { - tx := st.gdb.Begin() + tx := st.db.Begin() defer func() { if r := recover(); r != nil { @@ -540,7 +529,7 @@ func (st SqlStore) DeletePrincipal(ctx context.Context, id string) error { func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, error) { var dbPolicy dbPolicy - if err := st.gdb.Where("id = ?", policyId).First(&dbPolicy).Error; err != nil { + if err := st.db.Where("id = ?", policyId).First(&dbPolicy).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, &NotFoundError{"policy", policyId} } @@ -568,7 +557,7 @@ func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Poli } var dbPolicies []dbPolicy - if err := st.gdb.Where("id IN ?", policyIds).Find(&dbPolicies).Error; err != nil { + if err := st.db.Where("id IN ?", policyIds).Find(&dbPolicies).Error; err != nil { return nil, err } @@ -591,7 +580,7 @@ func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Poli } func (st SqlStore) CreatePolicy(ctx context.Context, p *Policy) error { - tx := st.gdb.Begin() + tx := st.db.Begin() defer func() { if r := recover(); r != nil { @@ -624,7 +613,7 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p *Policy) error { } func (st SqlStore) DeletePolicy(ctx context.Context, policyID string) error { - tx := st.gdb.Begin() + tx := st.db.Begin() defer func() { if r := recover(); r != nil { @@ -662,26 +651,26 @@ func (st SqlStore) CreateToken(ctx context.Context, tokenId string, value string Id: tokenId, Value: value, } - err := st.gdb.Create(&dbToken).Error + err := st.db.Create(&dbToken).Error return err } func (st SqlStore) DeleteToken(ctx context.Context, tokenId string) error { - err := st.gdb.Where("id = ?", tokenId).Delete(&dbToken{}).Error + err := st.db.Where("id = ?", tokenId).Delete(&dbToken{}).Error return err } func (st SqlStore) GetTokenValue(ctx context.Context, tokenId string) (string, error) { var dbToken dbToken - err := st.gdb.Where("id = ?", tokenId).First(&dbToken).Error + err := st.db.Where("id = ?", tokenId).First(&dbToken).Error return dbToken.Value, err } func (st SqlStore) Flush(ctx context.Context) error { // Drop all tables - tables, err := st.gdb.Migrator().GetTables() + tables, err := st.db.Migrator().GetTables() for _, table := range tables { - if err := st.gdb.Migrator().DropTable(table); err != nil { + if err := st.db.Migrator().DropTable(table); err != nil { return err } }