Skip to content
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 288 deletions.
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ Based on [Officially Supported Databases](https://v1.gorm.io/docs/connecting_to_
- PostgreSQL
- SQL Server
- Sqlite3
> Since `` sqlite`` needs ``cgo`` support, it is only supported in branch ``sqlite`` instead of ``master``.See detail: [gorm-adapter#93](https://github.com/casbin/gorm-adapter/issues/93). If you want to use it, maybe you need to merge ``sqlite`` to ``master`` branch manually
> gorm-adapter use ``github.com/glebarez/sqlite`` instead of gorm official sqlite driver ``gorm.io/driver/sqlite`` because the latter needs ``cgo`` support. But there is almost no difference between the two driver. If there is a difference in use, please submit an issue.
You may find other 3rd-party supported DBs in Gorm website or other places.
- other 3rd-party supported DBs in Gorm website or other places.

## Installation

Expand Down Expand Up @@ -69,7 +69,16 @@ func main() {
e.SavePolicy()
}
```

## Turn off AutoMigrate
New an adapter will use ``AutoMigrate`` by default for create table, if you want to turn it off, please use API ``TurnOffAutoMigrate(db *gorm.DB) *gorm.DB``. See example:
```go
db, err := gorm.Open(mysql.Open("root:@tcp(127.0.0.1:3306)/casbin"), &gorm.Config{})
TurnOffAutoMigrate(db)
// a,_ := NewAdapterByDB(...)
// a,_ := NewAdapterByDBUseTableName(...)
a,_ := NewAdapterByDBWithCustomTable(...)
```
Find out more details at [gorm-adapter#162](https://github.com/casbin/gorm-adapter/issues/162)
## Customize table columns example
You can change the gorm struct tags, but the table structure must stay the same.
```go
Expand Down
112 changes: 97 additions & 15 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ import (

const (
defaultDatabaseName = "casbin"
defaultTableName = "sys_casbin_rule"
defaultTableName = "casbin_rule"
)

type customTableKey struct{}
const disableMigrateKey = "disableMigrateKey"
const customTableKey = "customTableKey"

type CasbinRule struct {
ID uint `gorm:"primaryKey;autoIncrement"`
Expand All @@ -31,6 +32,8 @@ type CasbinRule struct {
V3 string `gorm:"size:100"`
V4 string `gorm:"size:100"`
V5 string `gorm:"size:100"`
V6 string `gorm:"size:25"`
V7 string `gorm:"size:25"`
}

func (CasbinRule) TableName() string {
Expand All @@ -45,6 +48,8 @@ type Filter struct {
V3 []string
V4 []string
V5 []string
V6 []string
V7 []string
}

// Adapter represents the Gorm adapter for policy storage.
Expand Down Expand Up @@ -174,9 +179,10 @@ func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (*
}

a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context})

err := a.createTable()
if err != nil {
return nil, err
return a, err
}

return a, nil
Expand Down Expand Up @@ -228,13 +234,24 @@ func NewAdapterByDB(db *gorm.DB) (*Adapter, error) {
return NewAdapterByDBUseTableName(db, "", defaultTableName)
}

func TurnOffAutoMigrate(db *gorm.DB) {
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}

ctx = context.WithValue(ctx, disableMigrateKey, false)

*db = *db.WithContext(ctx)
}

func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}, tableName ...string) (*Adapter, error) {
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}

ctx = context.WithValue(ctx, customTableKey{}, t)
ctx = context.WithValue(ctx, customTableKey, t)

curTableName := defaultTableName
if len(tableName) > 0 {
Expand Down Expand Up @@ -333,7 +350,12 @@ func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB {
}

func (a *Adapter) createTable() error {
t := a.db.Statement.Context.Value(customTableKey{})
disableMigrate := a.db.Statement.Context.Value(disableMigrateKey)
if disableMigrate != nil {
return nil
}

t := a.db.Statement.Context.Value(customTableKey)

if t != nil {
return a.db.AutoMigrate(t)
Expand All @@ -348,26 +370,34 @@ func (a *Adapter) createTable() error {
index := strings.ReplaceAll("idx_"+tableName, ".", "_")
hasIndex := a.db.Migrator().HasIndex(t, index)
if !hasIndex {
if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (ptype,v0,v1,v2,v3,v4,v5)", index, tableName)).Error; err != nil {
if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (ptype,v0,v1,v2,v3,v4,v5,v6,v7)", index, tableName)).Error; err != nil {
return err
}
}
return nil
}

func (a *Adapter) dropTable() error {
t := a.db.Statement.Context.Value(customTableKey{})
t := a.db.Statement.Context.Value(customTableKey)
if t == nil {
return a.db.Migrator().DropTable(a.getTableInstance())
}

return a.db.Migrator().DropTable(t)
}

func (a *Adapter) truncateTable() error {
if a.db.Config.Name() == sqlite.DriverName {
return a.db.Exec(fmt.Sprintf("delete from %s", a.getFullTableName())).Error
}
return a.db.Exec(fmt.Sprintf("truncate table %s", a.getFullTableName())).Error
}

func loadPolicyLine(line CasbinRule, model model.Model) {
var p = []string{line.Ptype,
line.V0, line.V1, line.V2,
line.V3, line.V4, line.V5}
line.V3, line.V4, line.V5,
line.V6, line.V7}

index := len(p) - 1
for p[index] == "" {
Expand Down Expand Up @@ -443,6 +473,12 @@ func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gor
if len(filter.V5) > 0 {
db = db.Where("v5 in (?)", filter.V5)
}
if len(filter.V6) > 0 {
db = db.Where("v6 in (?)", filter.V6)
}
if len(filter.V7) > 0 {
db = db.Where("v7 in (?)", filter.V7)
}
return db
}
}
Expand All @@ -469,17 +505,19 @@ func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule {
if len(rule) > 5 {
line.V5 = rule[5]
}
if len(rule) > 6 {
line.V6 = rule[6]
}
if len(rule) > 7 {
line.V7 = rule[7]
}

return *line
}

// SavePolicy saves policy to database.
func (a *Adapter) SavePolicy(model model.Model) error {
err := a.dropTable()
if err != nil {
return err
}
err = a.createTable()
err := a.truncateTable()
if err != nil {
return err
}
Expand Down Expand Up @@ -509,8 +547,10 @@ func (a *Adapter) SavePolicy(model model.Model) error {
}
}
}
if err := a.db.Create(&lines).Error; err != nil {
return err
if len(lines) > 0 {
if err := a.db.Create(&lines).Error; err != nil {
return err
}
}

return nil
Expand Down Expand Up @@ -586,6 +626,12 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line.V5 = fieldValues[5-fieldIndex]
}
if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) {
line.V6 = fieldValues[6-fieldIndex]
}
if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) {
line.V7 = fieldValues[7-fieldIndex]
}
err = a.rawDelete(a.db, *line)
return err
}
Expand Down Expand Up @@ -628,6 +674,14 @@ func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
queryStr += " and v5 = ?"
queryArgs = append(queryArgs, line.V5)
}
if line.V6 != "" {
queryStr += " and v6 = ?"
queryArgs = append(queryArgs, line.V6)
}
if line.V7 != "" {
queryStr += " and v7 = ?"
queryArgs = append(queryArgs, line.V7)
}
args := append([]interface{}{queryStr}, queryArgs...)
err := db.Delete(a.getTableInstance(), args...).Error
return err
Expand Down Expand Up @@ -661,6 +715,14 @@ func appendWhere(line CasbinRule) (string, []interface{}) {
queryStr += " and v5 = ?"
queryArgs = append(queryArgs, line.V5)
}
if line.V6 != "" {
queryStr += " and v6 = ?"
queryArgs = append(queryArgs, line.V6)
}
if line.V7 != "" {
queryStr += " and v7 = ?"
queryArgs = append(queryArgs, line.V7)
}
return queryStr, queryArgs
}

Expand Down Expand Up @@ -713,6 +775,12 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line.V5 = fieldValues[5-fieldIndex]
}
if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) {
line.V6 = fieldValues[6-fieldIndex]
}
if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) {
line.V7 = fieldValues[7-fieldIndex]
}

newP := make([]CasbinRule, 0, len(newPolicies))
oldP := make([]CasbinRule, 0)
Expand Down Expand Up @@ -775,6 +843,14 @@ func (c *CasbinRule) queryString() (interface{}, []interface{}) {
queryStr += " and v5 = ?"
queryArgs = append(queryArgs, c.V5)
}
if c.V6 != "" {
queryStr += " and v6 = ?"
queryArgs = append(queryArgs, c.V6)
}
if c.V7 != "" {
queryStr += " and v7 = ?"
queryArgs = append(queryArgs, c.V7)
}

return queryStr, queryArgs
}
Expand Down Expand Up @@ -802,5 +878,11 @@ func (c *CasbinRule) toStringPolicy() []string {
if c.V5 != "" {
policy = append(policy, c.V5)
}
if c.V6 != "" {
policy = append(policy, c.V6)
}
if c.V7 != "" {
policy = append(policy, c.V7)
}
return policy
}
Loading

0 comments on commit 096f0ef

Please sign in to comment.