Skip to content

Commit

Permalink
sql working now
Browse files Browse the repository at this point in the history
  • Loading branch information
subroseio committed Dec 11, 2023
1 parent b50dca6 commit 0a4aa10
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 42 deletions.
3 changes: 0 additions & 3 deletions vault/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ require (

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/doug-martin/goqu v5.0.0+incompatible // indirect
github.com/doug-martin/goqu/v9 v9.19.0 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/huandu/go-sqlbuilder v1.24.0 // indirect
Expand Down
67 changes: 28 additions & 39 deletions vault/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,25 @@ import (
"regexp"
"time"

"github.com/doug-martin/goqu/v9"
_ "github.com/doug-martin/goqu/v9/dialect/postgres"
"github.com/lib/pq"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)

type SqlStore struct {
db *goqu.Database
gdb *gorm.DB
db *gorm.DB
}

func NewSqlStore(dsn string) (*SqlStore, error) {
pgDb, err := sql.Open("postgres", dsn)
if err != nil {
return nil, err
}
dialect := goqu.Dialect("postgres")
db := dialect.DB(pgDb)

time.Local = time.UTC

gdb, err := gorm.Open(postgres.Open(dsn), &gorm.Config{TranslateError: true, Logger: logger.Default.LogMode(logger.Silent)})
// gdb = gdb.Debug()
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{TranslateError: true, Logger: logger.Default.LogMode(logger.Silent)})
// db = db.Debug()
if err != nil {
return nil, err
}

store := &SqlStore{db, gdb}
store := &SqlStore{db}

err = store.CreateSchemas()
if err != nil {
Expand Down Expand Up @@ -138,7 +127,7 @@ func (dbCollectionMetadata) TableName() string {

func (st *SqlStore) CreateSchemas() error {
// Use GORM's automigrate to create tables
err := st.gdb.AutoMigrate(&dbPrincipal{}, &dbPolicy{}, &dbPrincipalPolicy{}, &dbToken{}, &dbCollectionMetadata{})
err := st.db.AutoMigrate(&dbPrincipal{}, &dbPolicy{}, &dbPrincipalPolicy{}, &dbToken{}, &dbCollectionMetadata{})
if err != nil {
return err
}
Expand All @@ -153,7 +142,7 @@ func validateInput(input string) bool {

func (st *SqlStore) CreateCollection(ctx context.Context, c *Collection) error {
// Start a new transaction
tx := st.gdb.Begin()
tx := st.db.Begin()
if tx.Error != nil {
return tx.Error
}
Expand Down Expand Up @@ -220,7 +209,7 @@ func getCollectionFields(ctx context.Context, db *gorm.DB, collectionName string
}

func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection, error) {
fields, err := getCollectionFields(ctx, st.gdb, name)
fields, err := getCollectionFields(ctx, st.db, name)
if err != nil {
return nil, err
}
Expand All @@ -229,7 +218,7 @@ func (st SqlStore) GetCollection(ctx context.Context, name string) (*Collection,

func (st SqlStore) GetCollections(ctx context.Context) ([]string, error) {
var collectionMetadatas []dbCollectionMetadata
result := st.gdb.Find(&collectionMetadatas)
result := st.db.Find(&collectionMetadatas)
if result.Error != nil {
return nil, result.Error
}
Expand All @@ -244,7 +233,7 @@ func (st SqlStore) DeleteCollection(ctx context.Context, name string) error {
if !validateInput(name) {
return &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", name)}
}
tx := st.gdb.Begin()
tx := st.db.Begin()
if tx.Error != nil {
return tx.Error
}
Expand Down Expand Up @@ -282,7 +271,7 @@ func (st SqlStore) CreateRecord(ctx context.Context, collectionName string, reco
return "", &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", collectionName)}
}

fields, err := getCollectionFields(ctx, st.gdb, collectionName)
fields, err := getCollectionFields(ctx, st.db, collectionName)
if err != nil {
return "", err
}
Expand All @@ -306,7 +295,7 @@ func (st SqlStore) CreateRecord(ctx context.Context, collectionName string, reco
}

// Use gorm's Create method with the map
result := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Create(&newRecord)
result := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Create(&newRecord)
if result.Error != nil {
return "", result.Error
}
Expand All @@ -320,7 +309,7 @@ func (st SqlStore) GetRecords(ctx context.Context, collectionName string) ([]str
}

var recordIds []string
result := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Pluck("id", &recordIds)
result := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Pluck("id", &recordIds)
if result.Error != nil {
return nil, result.Error
}
Expand All @@ -334,7 +323,7 @@ func (st SqlStore) GetRecord(ctx context.Context, collectionName string, recordI
}

record := make(Record)
rows, err := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Select("*").Rows()
rows, err := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Select("*").Rows()
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, &NotFoundError{"record", recordID}
Expand Down Expand Up @@ -379,7 +368,7 @@ func (st SqlStore) UpdateRecord(ctx context.Context, collectionName string, reco
return &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", collectionName)}
}

fields, err := getCollectionFields(ctx, st.gdb, collectionName)
fields, err := getCollectionFields(ctx, st.db, collectionName)
if err != nil {
return err
}
Expand All @@ -402,7 +391,7 @@ func (st SqlStore) UpdateRecord(ctx context.Context, collectionName string, reco
}

// Use gorm's Update method with the map
result := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Updates(newRecord)
result := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Updates(newRecord)
if result.Error != nil {
return result.Error
}
Expand All @@ -418,7 +407,7 @@ func (st SqlStore) DeleteRecord(ctx context.Context, collectionName string, reco
return &ValueError{Msg: fmt.Sprintf("Invalid collection name %s", collectionName)}
}

result := st.gdb.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Delete(&Record{})
result := st.db.Table(fmt.Sprintf("collection_%s", collectionName)).Where("id = ?", recordID).Delete(&Record{})
if result.Error != nil {
return result.Error
}
Expand All @@ -431,7 +420,7 @@ func (st SqlStore) DeleteRecord(ctx context.Context, collectionName string, reco

func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principal, error) {
var dbPrincipal dbPrincipal
err := st.gdb.Preload("Policies").Where("username = ?", username).First(&dbPrincipal).Error
err := st.db.Preload("Policies").Where("username = ?", username).First(&dbPrincipal).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, &NotFoundError{"principal", username}
Expand All @@ -457,7 +446,7 @@ func (st SqlStore) GetPrincipal(ctx context.Context, username string) (*Principa
}

func (st SqlStore) CreatePrincipal(ctx context.Context, principal *Principal) error {
tx := st.gdb.Begin()
tx := st.db.Begin()

defer func() {
if r := recover(); r != nil {
Expand Down Expand Up @@ -505,7 +494,7 @@ func (st SqlStore) CreatePrincipal(ctx context.Context, principal *Principal) er
}

func (st SqlStore) DeletePrincipal(ctx context.Context, id string) error {
tx := st.gdb.Begin()
tx := st.db.Begin()

defer func() {
if r := recover(); r != nil {
Expand Down Expand Up @@ -540,7 +529,7 @@ func (st SqlStore) DeletePrincipal(ctx context.Context, id string) error {
func (st SqlStore) GetPolicy(ctx context.Context, policyId string) (*Policy, error) {
var dbPolicy dbPolicy

if err := st.gdb.Where("id = ?", policyId).First(&dbPolicy).Error; err != nil {
if err := st.db.Where("id = ?", policyId).First(&dbPolicy).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, &NotFoundError{"policy", policyId}
}
Expand Down Expand Up @@ -568,7 +557,7 @@ func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Poli
}

var dbPolicies []dbPolicy
if err := st.gdb.Where("id IN ?", policyIds).Find(&dbPolicies).Error; err != nil {
if err := st.db.Where("id IN ?", policyIds).Find(&dbPolicies).Error; err != nil {
return nil, err
}

Expand All @@ -591,7 +580,7 @@ func (st SqlStore) GetPolicies(ctx context.Context, policyIds []string) ([]*Poli
}

func (st SqlStore) CreatePolicy(ctx context.Context, p *Policy) error {
tx := st.gdb.Begin()
tx := st.db.Begin()

defer func() {
if r := recover(); r != nil {
Expand Down Expand Up @@ -624,7 +613,7 @@ func (st SqlStore) CreatePolicy(ctx context.Context, p *Policy) error {
}

func (st SqlStore) DeletePolicy(ctx context.Context, policyID string) error {
tx := st.gdb.Begin()
tx := st.db.Begin()

defer func() {
if r := recover(); r != nil {
Expand Down Expand Up @@ -662,26 +651,26 @@ func (st SqlStore) CreateToken(ctx context.Context, tokenId string, value string
Id: tokenId,
Value: value,
}
err := st.gdb.Create(&dbToken).Error
err := st.db.Create(&dbToken).Error
return err
}

func (st SqlStore) DeleteToken(ctx context.Context, tokenId string) error {
err := st.gdb.Where("id = ?", tokenId).Delete(&dbToken{}).Error
err := st.db.Where("id = ?", tokenId).Delete(&dbToken{}).Error
return err
}

func (st SqlStore) GetTokenValue(ctx context.Context, tokenId string) (string, error) {
var dbToken dbToken
err := st.gdb.Where("id = ?", tokenId).First(&dbToken).Error
err := st.db.Where("id = ?", tokenId).First(&dbToken).Error
return dbToken.Value, err
}

func (st SqlStore) Flush(ctx context.Context) error {
// Drop all tables
tables, err := st.gdb.Migrator().GetTables()
tables, err := st.db.Migrator().GetTables()
for _, table := range tables {
if err := st.gdb.Migrator().DropTable(table); err != nil {
if err := st.db.Migrator().DropTable(table); err != nil {
return err
}
}
Expand Down

0 comments on commit 0a4aa10

Please sign in to comment.