| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456 | // Copyright 2017 The casbin Authors. All Rights Reserved.//// Licensed under the Apache License, Version 2.0 (the "License");// you may not use this file except in compliance with the License.// You may obtain a copy of the License at////      http://www.apache.org/licenses/LICENSE-2.0//// Unless required by applicable law or agreed to in writing, software// distributed under the License is distributed on an "AS IS" BASIS,// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.// See the License for the specific language governing permissions and// limitations under the License.package mycasbinimport (	"errors"	"fmt"	"strings"	"gorm.io/driver/mysql"	"gorm.io/gorm"	"github.com/casbin/casbin/v2/model"	"github.com/casbin/casbin/v2/persist")const (	defaultDatabaseName = "casbin"	defaultTableName    = "sys_casbin_rule")type customTableKey struct{}type CasbinRule struct {	PType string `gorm:"size:100"`	V0    string `gorm:"size:100"`	V1    string `gorm:"size:100"`	V2    string `gorm:"size:100"`	V3    string `gorm:"size:100"`	V4    string `gorm:"size:100"`	V5    string `gorm:"size:100"`}func (CasbinRule) TableName() string {	return "sys_casbin_rule"}type Filter struct {	PType []string	V0    []string	V1    []string	V2    []string	V3    []string	V4    []string	V5    []string}// Adapter represents the Gorm adapter for policy storage.type Adapter struct {	dataSourceName string	databaseName   string	tablePrefix    string	tableName      string	db             *gorm.DB	isFiltered     bool}// NewAdapterByDBUseTableName creates gorm-adapter by an existing Gorm instance and the specified table prefix and table name// Example: gormadapter.NewAdapterByDBUseTableName(&db, "cms", "casbin") Automatically generate table name like this "cms_casbin"func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (*Adapter, error) {	if len(tableName) == 0 {		tableName = defaultTableName	}	a := &Adapter{		tablePrefix: prefix,		tableName:   tableName,	}	a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context})	err := a.createTable()	if err != nil {		return nil, err	}	return a, nil}// NewAdapterByDB creates gorm-adapter by an existing Gorm instancefunc NewAdapterByDB(db *gorm.DB) (*Adapter, error) {	return NewAdapterByDBUseTableName(db, "", defaultTableName)}func openDBConnection(dataSourceName string) (*gorm.DB, error) {	return gorm.Open(mysql.Open(dataSourceName), &gorm.Config{})}func (a *Adapter) open() error {	var err error	var db *gorm.DB	db, err = openDBConnection(a.dataSourceName)	if err != nil {		return err	}	a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{})	return a.createTable()}func (a *Adapter) close() error {	a.db = nil	return nil}// getTableInstance return the dynamic table namefunc (a *Adapter) getTableInstance() *CasbinRule {	return &CasbinRule{}}func (a *Adapter) getFullTableName() string {	if a.tablePrefix != "" {		return a.tablePrefix + "_" + a.tableName	}	return a.tableName}func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB {	return func(db *gorm.DB) *gorm.DB {		tableName := a.getFullTableName()		return db.Table(tableName)	}}func (a *Adapter) createTable() error {	t := a.db.Statement.Context.Value(customTableKey{})	if t == nil {		t = a.getTableInstance()	}	if err := a.db.AutoMigrate(t); err != nil {		return err	}	tableName := a.getFullTableName()	index := "idx_" + tableName	hasIndex := a.db.Migrator().HasIndex(t, index)	if !hasIndex {		if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (p_type,v0,v1,v2,v3,v4,v5)", index, tableName)).Error; err != nil {			return err		}	}	return nil}func (a *Adapter) dropTable() error {	t := a.db.Statement.Context.Value(customTableKey{})	if t == nil {		return a.db.Migrator().DropTable(a.getTableInstance())	}	return a.db.Migrator().DropTable(t)}func loadPolicyLine(line CasbinRule, model model.Model) {	var p = []string{line.PType,		line.V0, line.V1, line.V2, line.V3, line.V4, line.V5}	var lineText string	if line.V5 != "" {		lineText = strings.Join(p, ", ")	} else if line.V4 != "" {		lineText = strings.Join(p[:6], ", ")	} else if line.V3 != "" {		lineText = strings.Join(p[:5], ", ")	} else if line.V2 != "" {		lineText = strings.Join(p[:4], ", ")	} else if line.V1 != "" {		lineText = strings.Join(p[:3], ", ")	} else if line.V0 != "" {		lineText = strings.Join(p[:2], ", ")	}	persist.LoadPolicyLine(lineText, model)}// LoadPolicy loads policy from database.func (a *Adapter) LoadPolicy(model model.Model) error {	var lines []CasbinRule	if err := a.db.Find(&lines).Error; err != nil {		return err	}	for _, line := range lines {		loadPolicyLine(line, model)	}	return nil}// LoadFilteredPolicy loads only policy rules that match the filter.func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {	var lines []CasbinRule	filterValue, ok := filter.(Filter)	if !ok {		return errors.New("invalid filter type")	}	if err := a.db.Scopes(a.filterQuery(a.db, filterValue)).Find(&lines).Error; err != nil {		return err	}	for _, line := range lines {		loadPolicyLine(line, model)	}	a.isFiltered = true	return nil}// IsFiltered returns true if the loaded policy has been filtered.func (a *Adapter) IsFiltered() bool {	return a.isFiltered}// filterQuery builds the gorm query to match the rule filter to use within a scope.func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gorm.DB {	return func(db *gorm.DB) *gorm.DB {		if len(filter.PType) > 0 {			db = db.Where("p_type in (?)", filter.PType)		}		if len(filter.V0) > 0 {			db = db.Where("v0 in (?)", filter.V0)		}		if len(filter.V1) > 0 {			db = db.Where("v1 in (?)", filter.V1)		}		if len(filter.V2) > 0 {			db = db.Where("v2 in (?)", filter.V2)		}		if len(filter.V3) > 0 {			db = db.Where("v3 in (?)", filter.V3)		}		if len(filter.V4) > 0 {			db = db.Where("v4 in (?)", filter.V4)		}		if len(filter.V5) > 0 {			db = db.Where("v5 in (?)", filter.V5)		}		return db	}}func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule {	line := a.getTableInstance()	line.PType = ptype	if len(rule) > 0 {		line.V0 = rule[0]	}	if len(rule) > 1 {		line.V1 = rule[1]	}	if len(rule) > 2 {		line.V2 = rule[2]	}	if len(rule) > 3 {		line.V3 = rule[3]	}	if len(rule) > 4 {		line.V4 = rule[4]	}	if len(rule) > 5 {		line.V5 = rule[5]	}	return *line}// SavePolicy saves policy to database.func (a *Adapter) SavePolicy(model model.Model) error {	err := a.dropTable()	if err != nil {		return err	}	err = a.createTable()	if err != nil {		return err	}	for ptype, ast := range model["p"] {		for _, rule := range ast.Policy {			line := a.savePolicyLine(ptype, rule)			err := a.db.Create(&line).Error			if err != nil {				return err			}		}	}	for ptype, ast := range model["g"] {		for _, rule := range ast.Policy {			line := a.savePolicyLine(ptype, rule)			err := a.db.Create(&line).Error			if err != nil {				return err			}		}	}	return nil}// AddPolicy adds a policy rule to the storage.func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {	line := a.savePolicyLine(ptype, rule)	err := a.db.Create(&line).Error	return err}// RemovePolicy removes a policy rule from the storage.func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {	line := a.savePolicyLine(ptype, rule)	err := a.rawDelete(a.db, line) //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete	return err}// AddPolicies adds multiple policy rules to the storage.func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error {	return a.db.Transaction(func(tx *gorm.DB) error {		for _, rule := range rules {			line := a.savePolicyLine(ptype, rule)			if err := tx.Create(&line).Error; err != nil {				return err			}		}		return nil	})}// RemovePolicies removes multiple policy rules from the storage.func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error {	return a.db.Transaction(func(tx *gorm.DB) error {		for _, rule := range rules {			line := a.savePolicyLine(ptype, rule)			if err := a.rawDelete(tx, line); err != nil { //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete				return err			}		}		return nil	})}// RemoveFilteredPolicy removes policy rules that match the filter from the storage.func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {	line := a.getTableInstance()	line.PType = ptype	if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {		line.V0 = fieldValues[0-fieldIndex]	}	if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {		line.V1 = fieldValues[1-fieldIndex]	}	if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {		line.V2 = fieldValues[2-fieldIndex]	}	if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {		line.V3 = fieldValues[3-fieldIndex]	}	if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {		line.V4 = fieldValues[4-fieldIndex]	}	if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {		line.V5 = fieldValues[5-fieldIndex]	}	err := a.rawDelete(a.db, *line)	return err}func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {	queryArgs := []interface{}{line.PType}	queryStr := "p_type = ?"	if line.V0 != "" {		queryStr += " and v0 = ?"		queryArgs = append(queryArgs, line.V0)	}	if line.V1 != "" {		queryStr += " and v1 = ?"		queryArgs = append(queryArgs, line.V1)	}	if line.V2 != "" {		queryStr += " and v2 = ?"		queryArgs = append(queryArgs, line.V2)	}	if line.V3 != "" {		queryStr += " and v3 = ?"		queryArgs = append(queryArgs, line.V3)	}	if line.V4 != "" {		queryStr += " and v4 = ?"		queryArgs = append(queryArgs, line.V4)	}	if line.V5 != "" {		queryStr += " and v5 = ?"		queryArgs = append(queryArgs, line.V5)	}	args := append([]interface{}{queryStr}, queryArgs...)	err := db.Delete(a.getTableInstance(), args...).Error	return err}func appendWhere(line CasbinRule) (string, []interface{}) {	queryArgs := []interface{}{line.PType}	queryStr := "p_type = ?"	if line.V0 != "" {		queryStr += " and v0 = ?"		queryArgs = append(queryArgs, line.V0)	}	if line.V1 != "" {		queryStr += " and v1 = ?"		queryArgs = append(queryArgs, line.V1)	}	if line.V2 != "" {		queryStr += " and v2 = ?"		queryArgs = append(queryArgs, line.V2)	}	if line.V3 != "" {		queryStr += " and v3 = ?"		queryArgs = append(queryArgs, line.V3)	}	if line.V4 != "" {		queryStr += " and v4 = ?"		queryArgs = append(queryArgs, line.V4)	}	if line.V5 != "" {		queryStr += " and v5 = ?"		queryArgs = append(queryArgs, line.V5)	}	return queryStr, queryArgs}// UpdatePolicy updates a new policy rule to DB.func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []string) error {	oldLine := a.savePolicyLine(ptype, oldRule)	queryStr, queryArgs := appendWhere(oldLine)	newLine := a.savePolicyLine(ptype, newPolicy)	err := a.db.Where(queryStr, queryArgs...).Updates(newLine).Error	return err}
 |