From 56df1ca2fe589a4da43746d063e7bd6feda77571 Mon Sep 17 00:00:00 2001 From: shenghui Date: Thu, 25 Jan 2024 10:44:47 +0800 Subject: [PATCH] Add new files and functions --- coord/location.go | 4 - db/db.go => db.go | 59 +- db/transaction.go | 37 -- util/file.go => file.go | 2 +- util/file_test.go => file_test.go | 2 +- util/form.go => form.go | 2 +- util/form_test.go => form_test.go | 2 +- util/helper.go => helper.go | 5 +- util/helper_test.go => helper_test.go | 2 +- util/ip.go => ip.go | 2 +- util/ip_test.go => ip_test.go | 2 +- logger/zap.go => logger.go | 53 +- mutex/distributed.go => mutex.go | 9 +- util/slice.go => slice.go | 13 +- util/slice_test.go => slice_test.go | 2 +- sql_builder.go | 889 ++++++++++++++++++++++++++ sql_builder_test.go | 408 ++++++++++++ util/string.go => string.go | 2 +- util/string_test.go => string_test.go | 2 +- util/time.go => time.go | 2 +- util/time_test.go => time_test.go | 2 +- type.go | 13 + websocket/dial.go | 8 - websocket/upgrader.go | 8 - 24 files changed, 1397 insertions(+), 133 deletions(-) rename db/db.go => db.go (53%) delete mode 100644 db/transaction.go rename util/file.go => file.go (98%) rename util/file_test.go => file_test.go (96%) rename util/form.go => form.go (99%) rename util/form_test.go => form_test.go (99%) rename util/helper.go => helper.go (96%) rename util/helper_test.go => helper_test.go (99%) rename util/ip.go => ip.go (96%) rename util/ip_test.go => ip_test.go (94%) rename logger/zap.go => logger.go (59%) rename mutex/distributed.go => mutex.go (87%) rename util/slice.go => slice.go (79%) rename util/slice_test.go => slice_test.go (98%) create mode 100644 sql_builder.go create mode 100644 sql_builder_test.go rename util/string.go => string.go (98%) rename util/string_test.go => string_test.go (96%) rename util/time.go => time.go (98%) rename util/time_test.go => time_test.go (97%) create mode 100644 type.go diff --git a/coord/location.go b/coord/location.go index d3858cf..eb50e92 100644 --- a/coord/location.go +++ b/coord/location.go @@ -77,21 +77,17 @@ func (l *Location) Azimuth(t *Location) float64 { sinc := math.Sqrt(1 - cosc*cosc) sinA := math.Sin(a) * math.Sin(AOC_BOC) / sinc - if sinA > 1 { sinA = 1 } - if sinA < -1 { sinA = -1 } angle := math.Asin(sinA) / math.Pi * 180 - if t.lat < l.lat { return 180 - angle } - if t.lng < l.lng { return 360 + angle } diff --git a/db/db.go b/db.go similarity index 53% rename from db/db.go rename to db.go index a570d4b..4704d01 100644 --- a/db/db.go +++ b/db.go @@ -1,7 +1,9 @@ -package db +package yiigo import ( + "context" "database/sql" + "fmt" "time" _ "github.com/go-sql-driver/mysql" @@ -10,8 +12,8 @@ import ( _ "github.com/mattn/go-sqlite3" ) -// Config 数据库初始化配置 -type Config struct { +// DBConfig 数据库初始化配置 +type DBConfig struct { // Driver 驱动名称 Driver string // DSN 数据源名称 @@ -19,12 +21,6 @@ type Config struct { // [Postgres] host=localhost port=5432 user=root password=secret dbname=test connect_timeout=10 sslmode=disable // [- SQLite] file::memory:?cache=shared DSN string - // Options 配置选项 - Options *Options -} - -// Options 数据库配置选项 -type Options struct { // MaxOpenConns 设置最大可打开的连接数 MaxOpenConns int // MaxIdleConns 连接池最大闲置连接数 @@ -35,7 +31,7 @@ type Options struct { ConnMaxIdleTime time.Duration } -func New(cfg *Config) (*sql.DB, error) { +func NewDB(cfg *DBConfig) (*sql.DB, error) { db, err := sql.Open(cfg.Driver, cfg.DSN) if err != nil { return nil, err @@ -45,21 +41,48 @@ func New(cfg *Config) (*sql.DB, error) { return nil, err } - if cfg.Options != nil { - db.SetMaxOpenConns(cfg.Options.MaxOpenConns) - db.SetMaxIdleConns(cfg.Options.MaxIdleConns) - db.SetConnMaxLifetime(cfg.Options.ConnMaxLifetime) - db.SetConnMaxIdleTime(cfg.Options.ConnMaxIdleTime) - } + db.SetMaxOpenConns(cfg.MaxOpenConns) + db.SetMaxIdleConns(cfg.MaxIdleConns) + db.SetConnMaxLifetime(cfg.ConnMaxLifetime) + db.SetConnMaxIdleTime(cfg.ConnMaxIdleTime) return db, nil } -func NewX(cfg *Config) (*sqlx.DB, error) { - db, err := New(cfg) +func NewDBX(cfg *DBConfig) (*sqlx.DB, error) { + db, err := NewDB(cfg) if err != nil { return nil, err } return sqlx.NewDb(db, cfg.Driver), nil } + +// Transaction 执行数据库事物 +func Transaction(ctx context.Context, db *sqlx.DB, f func(ctx context.Context, tx *sqlx.Tx) error) error { + tx, err := db.BeginTxx(ctx, nil) + if err != nil { + return err + } + + defer func() { + if v := recover(); v != nil { + tx.Rollback() + panic(v) + } + }() + + if err = f(ctx, tx); err != nil { + if rerr := tx.Rollback(); rerr != nil { + err = fmt.Errorf("%w: rolling back transaction: %v", err, rerr) + } + + return err + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +} diff --git a/db/transaction.go b/db/transaction.go deleted file mode 100644 index dabe6af..0000000 --- a/db/transaction.go +++ /dev/null @@ -1,37 +0,0 @@ -package db - -import ( - "context" - "fmt" - - "github.com/jmoiron/sqlx" -) - -// Transaction 执行数据库事物 -func Transaction(ctx context.Context, db *sqlx.DB, f func(ctx context.Context, tx *sqlx.Tx) error) error { - tx, err := db.BeginTxx(ctx, nil) - if err != nil { - return err - } - - defer func() { - if v := recover(); v != nil { - tx.Rollback() - panic(v) - } - }() - - if err = f(ctx, tx); err != nil { - if rerr := tx.Rollback(); rerr != nil { - err = fmt.Errorf("%w: rolling back transaction: %v", err, rerr) - } - - return err - } - - if err = tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - - return nil -} diff --git a/util/file.go b/file.go similarity index 98% rename from util/file.go rename to file.go index c1ba08d..e6dc797 100644 --- a/util/file.go +++ b/file.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "os" diff --git a/util/file_test.go b/file_test.go similarity index 96% rename from util/file_test.go rename to file_test.go index 45586fe..35e7b07 100644 --- a/util/file_test.go +++ b/file_test.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "fmt" diff --git a/util/form.go b/form.go similarity index 99% rename from util/form.go rename to form.go index acf3574..df5296f 100644 --- a/util/form.go +++ b/form.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "encoding/json" diff --git a/util/form_test.go b/form_test.go similarity index 99% rename from util/form_test.go rename to form_test.go index 1ac5f70..0c8c04e 100644 --- a/util/form_test.go +++ b/form_test.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "reflect" diff --git a/util/helper.go b/helper.go similarity index 96% rename from util/helper.go rename to helper.go index 919e58a..f8c6146 100644 --- a/util/helper.go +++ b/helper.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "bytes" @@ -11,9 +11,6 @@ import ( "github.com/hashicorp/go-version" ) -// X 类型别名 -type X map[string]any - // Nonce 生成随机串(size应为偶数) func Nonce(size uint8) string { nonce := make([]byte, size/2) diff --git a/util/helper_test.go b/helper_test.go similarity index 99% rename from util/helper_test.go rename to helper_test.go index 876df8b..a016911 100644 --- a/util/helper_test.go +++ b/helper_test.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "testing" diff --git a/util/ip.go b/ip.go similarity index 96% rename from util/ip.go rename to ip.go index af633f5..e14c179 100644 --- a/util/ip.go +++ b/ip.go @@ -1,4 +1,4 @@ -package util +package yiigo import "net" diff --git a/util/ip_test.go b/ip_test.go similarity index 94% rename from util/ip_test.go rename to ip_test.go index 6c35564..90f2e9b 100644 --- a/util/ip_test.go +++ b/ip_test.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "testing" diff --git a/logger/zap.go b/logger.go similarity index 59% rename from logger/zap.go rename to logger.go index 90e0268..6d80c9a 100644 --- a/logger/zap.go +++ b/logger.go @@ -1,4 +1,4 @@ -package logger +package yiigo import ( "os" @@ -9,16 +9,10 @@ import ( "gopkg.in/natefinch/lumberjack.v2" ) -// Config 日志初始化配置 -type Config struct { +// LogConfig 日志初始化配置 +type LogConfig struct { // Filename 日志名称 Filename string - // Options 日志选项 - Options *Options -} - -// Options 日志配置选项 -type Options struct { // MaxSize 当前文件多大时轮替;默认:100MB MaxSize int // MaxAge 轮替的旧文件最大保留时长;默认:不限 @@ -29,11 +23,11 @@ type Options struct { Compress bool // Stderr 是否输出到控制台 Stderr bool - // ZapOpts Zap日志选项 - ZapOpts []zap.Option + // Options Zap日志选项 + Options []zap.Option } -func Debug(options ...zap.Option) *zap.Logger { +func DebugLogger(options ...zap.Option) *zap.Logger { cfg := zap.NewDevelopmentConfig() cfg.DisableCaller = true @@ -46,9 +40,9 @@ func Debug(options ...zap.Option) *zap.Logger { return logger } -func New(cfg *Config) *zap.Logger { +func NewLogger(cfg *LogConfig) *zap.Logger { if len(cfg.Filename) == 0 { - return Debug() + return DebugLogger(cfg.Options...) } ec := zap.NewProductionEncoderConfig() @@ -56,28 +50,21 @@ func New(cfg *Config) *zap.Logger { ec.EncodeTime = MyTimeEncoder ec.EncodeCaller = zapcore.FullCallerEncoder - var zapOpts []zap.Option - - w := &lumberjack.Logger{ - Filename: cfg.Filename, - LocalTime: true, + ws := []zapcore.WriteSyncer{ + zapcore.AddSync(&lumberjack.Logger{ + Filename: cfg.Filename, + MaxSize: cfg.MaxSize, + MaxAge: cfg.MaxAge, + MaxBackups: cfg.MaxBackups, + LocalTime: true, + Compress: cfg.Compress, + }), } - ws := make([]zapcore.WriteSyncer, 0, 2) - if cfg.Options != nil { - zapOpts = cfg.Options.ZapOpts - - w.MaxSize = cfg.Options.MaxSize - w.MaxAge = cfg.Options.MaxAge - w.MaxBackups = cfg.Options.MaxBackups - w.Compress = cfg.Options.Compress - - if cfg.Options.Stderr { - ws = append(ws, zapcore.Lock(os.Stderr)) - } + if cfg.Stderr { + ws = append(ws, zapcore.Lock(os.Stderr)) } - ws = append(ws, zapcore.AddSync(w)) - return zap.New(zapcore.NewCore(zapcore.NewJSONEncoder(ec), zapcore.NewMultiWriteSyncer(ws...), zap.InfoLevel), zapOpts...) + return zap.New(zapcore.NewCore(zapcore.NewJSONEncoder(ec), zapcore.NewMultiWriteSyncer(ws...), zap.InfoLevel), cfg.Options...) } // MyTimeEncoder 自定义时间格式化 diff --git a/mutex/distributed.go b/mutex.go similarity index 87% rename from mutex/distributed.go rename to mutex.go index eb20241..8b9863c 100644 --- a/mutex/distributed.go +++ b/mutex.go @@ -1,4 +1,4 @@ -package redis +package yiigo import ( "context" @@ -7,7 +7,7 @@ import ( "github.com/redis/go-redis/v9" ) -// Mutex 基于Redis实现的分布式锁 +// Mutex 分布式锁 type Mutex interface { // Lock 尝试获取锁;interval - 每隔指定时间尝试获取一次锁;timeout - 获取锁的超时时间 Lock(ctx context.Context, interval, timeout time.Duration) error @@ -15,6 +15,7 @@ type Mutex interface { UnLock(ctx context.Context) error } +// distributed 基于「Redis」实现的分布式锁 type distributed struct { cli *redis.Client key string @@ -58,8 +59,8 @@ func (d *distributed) UnLock(ctx context.Context) error { return d.cli.Del(ctx, d.key).Err() } -// DistributedMutex 返回一个分布式锁实例 -// uniqueID - 建议使用RequestID +// DistributedMutex 返回一个分布式锁实例; +// uniqueID - 建议使用「RequestID」 func DistributedMutex(cli *redis.Client, key, uniqueID string, expire time.Duration) Mutex { mutex := &distributed{ cli: cli, diff --git a/util/slice.go b/slice.go similarity index 79% rename from util/slice.go rename to slice.go index 01d2ed2..25945f3 100644 --- a/util/slice.go +++ b/slice.go @@ -1,15 +1,18 @@ -package util +package yiigo -import "math/rand" +import ( + "cmp" + "math/rand" +) // SliceUniq 切片去重 -func SliceUniq[T ~int | ~int64 | ~float64 | ~string](a []T) []T { - ret := make([]T, 0) +func SliceUniq[T ~[]E, E cmp.Ordered](a T) T { + ret := make(T, 0) if len(a) == 0 { return ret } - m := make(map[T]struct{}, 0) + m := make(map[E]struct{}, 0) for _, v := range a { if _, ok := m[v]; !ok { diff --git a/util/slice_test.go b/slice_test.go similarity index 98% rename from util/slice_test.go rename to slice_test.go index 3827b34..487a17e 100644 --- a/util/slice_test.go +++ b/slice_test.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "testing" diff --git a/sql_builder.go b/sql_builder.go new file mode 100644 index 0000000..ef239cd --- /dev/null +++ b/sql_builder.go @@ -0,0 +1,889 @@ +package yiigo + +import ( + "context" + "errors" + "reflect" + "strings" + + "github.com/jmoiron/sqlx" +) + +var ( + // ErrUpsertData 不合法的插入或更新数据类型错误 + ErrUpsertData = errors.New("invaild data, expects struct, *struct, yiigo.X") + + // ErrBatchInsertData 不合法的批量插入数据类型错误 + ErrBatchInsertData = errors.New("invaild data, expects []struct, []*struct, []yiigo.X") +) + +// SQLBuilder SQL构造器 +type SQLBuilder interface { + // Wrap 包装查询选项 + Wrap(options ...QueryOption) SQLWrapper +} + +// SQLWrapper SQL包装器 +type SQLWrapper interface { + // ToQuery 生成SELECT语句 + ToQuery(ctx context.Context) (sql string, args []any, err error) + // ToInsert 生成INSERT语句 + // 数据类型:`struct`, `*struct`, `yiigo.X`. + ToInsert(ctx context.Context, data any) (sql string, args []any, err error) + // ToBatchInsert 生成批量INSERT语句 + // 数据类型:`[]struct`, `[]*struct`, `[]yiigo.X`. + ToBatchInsert(ctx context.Context, data any) (sql string, args []any, err error) + // ToUpdate 生成UPDATE语句 + // 数据类型:`struct`, `*struct`, `yiigo.X`. + ToUpdate(ctx context.Context, data any) (sql string, args []any, err error) + // ToDelete 生成DELETE语句 + ToDelete(ctx context.Context) (sql string, args []any, err error) + // ToTruncate 生成TRUNCATE语句 + ToTruncate(ctx context.Context) string +} + +type queryBuilder struct { + driver DBDriver +} + +func (b *queryBuilder) Wrap(options ...QueryOption) SQLWrapper { + wrapper := &queryWrapper{ + builder: b, + columns: []string{"*"}, + } + + for _, f := range options { + f(wrapper) + } + + return wrapper +} + +// NewSQLBuilder 生成一个指定驱动类型的SQL构造器 +func NewSQLBuilder(driver DBDriver) SQLBuilder { + return &queryBuilder{ + driver: driver, + } +} + +// NewMySQLBuilder 生成一个MySQL构造器 +func NewMySQLBuilder() SQLBuilder { + return NewSQLBuilder(MySQL) +} + +// NewPGSQLBuilder 生成一个Postgres构造器 +func NewPGSQLBuilder() SQLBuilder { + return NewSQLBuilder(Postgres) +} + +// NewSQLiteBuilder 生成一个SQLite构造器 +func NewSQLiteBuilder() SQLBuilder { + return NewSQLBuilder(SQLite) +} + +// SQLClause SQL语句 +type SQLClause struct { + table string + keyword string + query string + binds []any +} + +// SQLExpr 生成一个语句表达式,例如:yiigo.SQLExpr("price * ? + ?", 2, 100) +func SQLExpr(query string, binds ...any) *SQLClause { + return &SQLClause{ + query: query, + binds: binds, + } +} + +type queryWrapper struct { + builder *queryBuilder + table string + columns []string + where *SQLClause + joins []*SQLClause + groups []string + having *SQLClause + orders []string + offset int + limit int + unions []*SQLClause + distinct bool + whereIn bool +} + +func (w *queryWrapper) ToQuery(ctx context.Context) (sql string, args []any, err error) { + sql, args = w.subquery() + + // unions + if l := len(w.unions); l != 0 { + var builder strings.Builder + + builder.WriteString("(") + builder.WriteString(sql) + builder.WriteString(")") + + for _, v := range w.unions { + builder.WriteString(" ") + builder.WriteString(v.keyword) + builder.WriteString(" (") + builder.WriteString(v.query) + builder.WriteString(")") + + args = append(args, v.binds...) + } + + sql = builder.String() + } + + // where in + if w.whereIn { + sql, args, err = sqlx.In(sql, args...) + if err != nil { + return + } + } + + sql = sqlx.Rebind(sqlx.BindType(string(w.builder.driver)), sql) + + return +} + +func (w *queryWrapper) subquery() (string, []any) { + binds := make([]any, 0) + + var builder strings.Builder + + builder.WriteString("SELECT ") + if w.distinct { + builder.WriteString("DISTINCT ") + } + + builder.WriteString(w.columns[0]) + + for _, column := range w.columns[1:] { + builder.WriteString(", ") + builder.WriteString(column) + } + + builder.WriteString(" FROM ") + builder.WriteString(w.table) + + if len(w.joins) != 0 { + for _, join := range w.joins { + builder.WriteString(" ") + builder.WriteString(join.keyword) + builder.WriteString(" JOIN ") + builder.WriteString(join.table) + + if len(join.query) != 0 { + builder.WriteString(" ON ") + builder.WriteString(join.query) + } + } + } + + if w.where != nil { + builder.WriteString(" WHERE ") + builder.WriteString(w.where.query) + + binds = append(binds, w.where.binds...) + } + + if len(w.groups) != 0 { + builder.WriteString(" GROUP BY ") + builder.WriteString(w.groups[0]) + + for _, column := range w.groups[1:] { + builder.WriteString(", ") + builder.WriteString(column) + } + } + + if w.having != nil { + builder.WriteString(" HAVING ") + builder.WriteString(w.having.query) + + binds = append(binds, w.having.binds...) + } + + if len(w.orders) != 0 { + builder.WriteString(" ORDER BY ") + builder.WriteString(w.orders[0]) + + for _, column := range w.orders[1:] { + builder.WriteString(", ") + builder.WriteString(column) + } + } + + if w.limit != 0 { + builder.WriteString(" LIMIT ?") + binds = append(binds, w.limit) + } + + if w.offset != 0 { + builder.WriteString(" OFFSET ?") + binds = append(binds, w.offset) + } + + return builder.String(), binds +} + +func (w *queryWrapper) ToInsert(ctx context.Context, data any) (sql string, args []any, err error) { + var columns []string + + v := reflect.Indirect(reflect.ValueOf(data)) + + switch v.Kind() { + case reflect.Map: + x, ok := data.(X) + if !ok { + err = ErrUpsertData + return + } + + columns, args = w.insertWithMap(x) + case reflect.Struct: + columns, args = w.insertWithStruct(v) + default: + err = ErrUpsertData + return + } + + var builder strings.Builder + + builder.WriteString("INSERT INTO ") + builder.WriteString(w.table) + + if l := len(columns); l != 0 { + builder.WriteString(" (") + builder.WriteString(columns[0]) + + for _, column := range columns[1:] { + builder.WriteString(", ") + builder.WriteString(column) + } + + builder.WriteString(") VALUES (?") + for i := 1; i < l; i++ { + builder.WriteString(", ?") + } + builder.WriteString(")") + } + + if w.builder.driver == Postgres { + builder.WriteString(" RETURNING id") + } + + sql = sqlx.Rebind(sqlx.BindType(string(w.builder.driver)), builder.String()) + + return +} + +func (w *queryWrapper) insertWithMap(data X) (columns []string, binds []any) { + fieldNum := len(data) + + columns = make([]string, 0, fieldNum) + binds = make([]any, 0, fieldNum) + + for k, v := range data { + columns = append(columns, k) + binds = append(binds, v) + } + + return +} + +func (w *queryWrapper) insertWithStruct(v reflect.Value) (columns []string, binds []any) { + fieldNum := v.NumField() + + columns = make([]string, 0, fieldNum) + binds = make([]any, 0, fieldNum) + + t := v.Type() + + for i := 0; i < fieldNum; i++ { + fieldT := t.Field(i) + + tag := fieldT.Tag.Get("db") + if tag == "-" { + continue + } + + fieldV := v.Field(i) + column := fieldT.Name + + if len(tag) != 0 { + name, opts := parseTag(tag) + if opts.Contains("omitempty") && isEmptyValue(fieldV) { + continue + } + + column = name + } + + columns = append(columns, column) + binds = append(binds, fieldV.Interface()) + } + + return +} + +func (w *queryWrapper) ToBatchInsert(ctx context.Context, data any) (sql string, args []any, err error) { + v := reflect.Indirect(reflect.ValueOf(data)) + + if v.Kind() != reflect.Slice { + err = ErrBatchInsertData + return + } + + if v.Len() == 0 { + err = errors.New("err empty data") + return + } + + var columns []string + + e := v.Type().Elem() + + switch e.Kind() { + case reflect.Map: + x, ok := data.([]X) + if !ok { + err = ErrBatchInsertData + return + } + + columns, args = w.batchInsertWithMap(x) + case reflect.Struct: + columns, args = w.batchInsertWithStruct(v) + case reflect.Ptr: + if e.Elem().Kind() != reflect.Struct { + err = ErrBatchInsertData + return + } + + columns, args = w.batchInsertWithStruct(v) + default: + err = ErrBatchInsertData + return + } + + var builder strings.Builder + + builder.WriteString("INSERT INTO ") + builder.WriteString(w.table) + + if l := len(columns); l != 0 { + builder.WriteString(" (") + builder.WriteString(columns[0]) + + for _, column := range columns[1:] { + builder.WriteString(", ") + builder.WriteString(column) + } + + // 首行 + builder.WriteString(") VALUES (?") + for i := 1; i < l; i++ { + builder.WriteString(", ?") + } + builder.WriteString(")") + + rows := len(args) / l + + // 其余行 + for i := 1; i < rows; i++ { + builder.WriteString(", (?") + for j := 1; j < l; j++ { + builder.WriteString(", ?") + } + builder.WriteString(")") + } + } + + sql = sqlx.Rebind(sqlx.BindType(string(w.builder.driver)), builder.String()) + + return +} + +func (w *queryWrapper) batchInsertWithMap(data []X) (columns []string, binds []any) { + dataLen := len(data) + fieldNum := len(data[0]) + + columns = make([]string, 0, fieldNum) + binds = make([]any, 0, fieldNum*dataLen) + + for k := range data[0] { + columns = append(columns, k) + } + + for _, x := range data { + for _, v := range columns { + binds = append(binds, x[v]) + } + } + + return +} + +func (w *queryWrapper) batchInsertWithStruct(v reflect.Value) (columns []string, binds []any) { + first := reflect.Indirect(v.Index(0)) + + dataLen := v.Len() + fieldNum := first.NumField() + + columns = make([]string, 0, fieldNum) + binds = make([]any, 0, fieldNum*dataLen) + + t := first.Type() + + for i := 0; i < dataLen; i++ { + for j := 0; j < fieldNum; j++ { + fieldT := t.Field(j) + + tag := fieldT.Tag.Get("db") + if tag == "-" { + continue + } + + fieldV := reflect.Indirect(v.Index(i)).Field(j) + column := fieldT.Name + + if len(tag) != 0 { + name, opts := parseTag(tag) + if opts.Contains("omitempty") && isEmptyValue(fieldV) { + continue + } + + column = name + } + + if i == 0 { + columns = append(columns, column) + } + + binds = append(binds, fieldV.Interface()) + } + } + + return +} + +func (w *queryWrapper) ToUpdate(ctx context.Context, data any) (sql string, args []any, err error) { + var ( + columns []string + exprs map[string]string + ) + + v := reflect.Indirect(reflect.ValueOf(data)) + + switch v.Kind() { + case reflect.Map: + x, ok := data.(X) + if !ok { + err = ErrUpsertData + return + } + + columns, exprs, args = w.updateWithMap(x) + case reflect.Struct: + columns, args = w.updateWithStruct(v) + default: + err = ErrUpsertData + return + } + + var builder strings.Builder + + builder.WriteString("UPDATE ") + builder.WriteString(w.table) + + if len(columns) != 0 { + builder.WriteString(" SET ") + builder.WriteString(columns[0]) + + if expr, ok := exprs[columns[0]]; ok { + builder.WriteString(" = ") + builder.WriteString(expr) + } else { + builder.WriteString(" = ?") + } + + for _, column := range columns[1:] { + builder.WriteString(", ") + builder.WriteString(column) + + if expr, ok := exprs[column]; ok { + builder.WriteString(" = ") + builder.WriteString(expr) + } else { + builder.WriteString(" = ?") + } + } + } + + if w.where != nil { + builder.WriteString(" WHERE ") + builder.WriteString(w.where.query) + + args = append(args, w.where.binds...) + } + + sql = builder.String() + + if w.whereIn { + sql, args, err = sqlx.In(sql, args...) + if err != nil { + return + } + } + + sql = sqlx.Rebind(sqlx.BindType(string(w.builder.driver)), sql) + + return +} + +func (w *queryWrapper) updateWithMap(data X) (columns []string, exprs map[string]string, binds []any) { + fieldNum := len(data) + + columns = make([]string, 0, fieldNum) + exprs = make(map[string]string) + binds = make([]any, 0, fieldNum) + + for k, v := range data { + columns = append(columns, k) + + if clause, ok := v.(*SQLClause); ok { + exprs[k] = clause.query + binds = append(binds, clause.binds...) + + continue + } + + binds = append(binds, v) + } + + return +} + +func (w *queryWrapper) updateWithStruct(v reflect.Value) (columns []string, binds []any) { + fieldNum := v.NumField() + + columns = make([]string, 0, fieldNum) + binds = make([]any, 0, fieldNum) + + t := v.Type() + + for i := 0; i < fieldNum; i++ { + fieldT := t.Field(i) + + tag := fieldT.Tag.Get("db") + if tag == "-" { + continue + } + + fieldV := v.Field(i) + column := fieldT.Name + + if len(tag) != 0 { + name, opts := parseTag(tag) + if opts.Contains("omitempty") && isEmptyValue(fieldV) { + continue + } + + column = name + } + + columns = append(columns, column) + binds = append(binds, fieldV.Interface()) + } + + return +} + +func (w *queryWrapper) ToDelete(ctx context.Context) (sql string, args []any, err error) { + var builder strings.Builder + + builder.WriteString("DELETE FROM ") + builder.WriteString(w.table) + + if w.where != nil { + builder.WriteString(" WHERE ") + builder.WriteString(w.where.query) + + args = append(args, w.where.binds...) + } + + sql = builder.String() + + if w.whereIn { + sql, args, err = sqlx.In(sql, args...) + if err != nil { + return + } + } + + sql = sqlx.Rebind(sqlx.BindType(string(w.builder.driver)), sql) + + return +} + +func (w *queryWrapper) ToTruncate(ctx context.Context) string { + var builder strings.Builder + + builder.WriteString("TRUNCATE ") + builder.WriteString(w.table) + + return builder.String() +} + +// QueryOption SQL查询选项 +type QueryOption func(w *queryWrapper) + +// Table 指定查询表名称 +func Table(name string) QueryOption { + return func(w *queryWrapper) { + w.table = name + } +} + +// Select 指定查询字段名 +func Select(columns ...string) QueryOption { + return func(w *queryWrapper) { + w.columns = columns + } +} + +// Distinct 指定 `DISTINCT` 语句 +func Distinct(columns ...string) QueryOption { + return func(w *queryWrapper) { + w.columns = columns + w.distinct = true + } +} + +// Join 指定 `INNER JOIN` 语句 +func Join(table, on string) QueryOption { + return func(w *queryWrapper) { + w.joins = append(w.joins, &SQLClause{ + table: table, + keyword: "INNER", + query: on, + }) + } +} + +// LeftJoin 指定 `LEFT JOIN` 语句 +func LeftJoin(table, on string) QueryOption { + return func(w *queryWrapper) { + w.joins = append(w.joins, &SQLClause{ + table: table, + keyword: "LEFT", + query: on, + }) + } +} + +// RightJoin 指定 `RIGHT JOIN` 语句 +func RightJoin(table, on string) QueryOption { + return func(w *queryWrapper) { + w.joins = append(w.joins, &SQLClause{ + table: table, + keyword: "RIGHT", + query: on, + }) + } +} + +// FullJoin 指定 `FULL JOIN` 语句 +func FullJoin(table, on string) QueryOption { + return func(w *queryWrapper) { + w.joins = append(w.joins, &SQLClause{ + table: table, + keyword: "FULL", + query: on, + }) + } +} + +// CrossJoin 指定 `CROSS JOIN` 语句 +func CrossJoin(table string) QueryOption { + return func(w *queryWrapper) { + w.joins = append(w.joins, &SQLClause{ + table: table, + keyword: "CROSS", + }) + } +} + +// Where 指定 `WHERE` 语句 +func Where(query string, binds ...any) QueryOption { + return func(w *queryWrapper) { + w.where = &SQLClause{ + query: query, + binds: binds, + } + } +} + +// WhereIn 指定 `WHERE IN` 语句 +func WhereIn(query string, binds ...any) QueryOption { + return func(w *queryWrapper) { + w.where = &SQLClause{ + query: query, + binds: binds, + } + + w.whereIn = true + } +} + +// GroupBy 指定 `GROUP BY` 语句 +func GroupBy(columns ...string) QueryOption { + return func(w *queryWrapper) { + w.groups = columns + } +} + +// Having 指定 `HAVING` 语句 +func Having(query string, binds ...any) QueryOption { + return func(w *queryWrapper) { + w.having = &SQLClause{ + query: query, + binds: binds, + } + } +} + +// OrderBy 指定 `ORDER BY` 语句 +func OrderBy(columns ...string) QueryOption { + return func(w *queryWrapper) { + w.orders = columns + } +} + +// Offset 指定 `OFFSET` 语句 +func Offset(n int) QueryOption { + return func(w *queryWrapper) { + w.offset = n + } +} + +// Limit 指定 `LIMIT` 语句 +func Limit(n int) QueryOption { + return func(w *queryWrapper) { + w.limit = n + } +} + +// Union 指定 `UNION` 语句 +func Union(wrappers ...SQLWrapper) QueryOption { + return func(w *queryWrapper) { + for _, wrapper := range wrappers { + v, ok := wrapper.(*queryWrapper) + if !ok { + continue + } + + if v.whereIn { + w.whereIn = true + } + + query, binds := v.subquery() + + w.unions = append(w.unions, &SQLClause{ + keyword: "UNION", + query: query, + binds: binds, + }) + } + } +} + +// UnionAll 指定 `UNION ALL` 语句 +func UnionAll(wrappers ...SQLWrapper) QueryOption { + return func(w *queryWrapper) { + for _, wrapper := range wrappers { + v, ok := wrapper.(*queryWrapper) + if !ok { + continue + } + + if v.whereIn { + w.whereIn = true + } + + query, binds := v.subquery() + + w.unions = append(w.unions, &SQLClause{ + keyword: "UNION ALL", + query: query, + binds: binds, + }) + } + } +} + +// tagOptions is the string following a comma in a struct field's "json" +// tag, or the empty string. It does not include the leading comma. +type tagOptions string + +// Contains reports whether a comma-separated list of options +// contains a particular substr flag. substr must be surrounded by a +// string boundary or commas. +func (o tagOptions) Contains(optionName string) bool { + if len(o) == 0 { + return false + } + + s := string(o) + + for len(s) != 0 { + var next string + + i := strings.Index(s, ",") + if i >= 0 { + s, next = s[:i], s[i+1:] + } + + if s == optionName { + return true + } + + s = next + } + + return false +} + +// parseTag splits a struct field's json tag into its name and +// comma-separated options. +func parseTag(tag string) (string, tagOptions) { + if idx := strings.Index(tag, ","); idx != -1 { + return tag[:idx], tagOptions(tag[idx+1:]) + } + + return tag, tagOptions("") +} + +func isEmptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + + return false +} diff --git a/sql_builder_test.go b/sql_builder_test.go new file mode 100644 index 0000000..1e3f8d4 --- /dev/null +++ b/sql_builder_test.go @@ -0,0 +1,408 @@ +package yiigo + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestToQuery(t *testing.T) { + ctx := context.TODO() + + builder := NewMySQLBuilder() + + sql, args, err := builder.Wrap( + Table("user"), + Where("id = ?", 1), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM user WHERE id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + Where("name = ? AND age > ?", "yiigo", 20), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM user WHERE name = ? AND age > ?", sql) + assert.Equal(t, []any{"yiigo", 20}, args) + + sql, args, err = builder.Wrap( + Table("user"), + WhereIn("age IN (?)", []int{20, 30}), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM user WHERE age IN (?, ?)", sql) + assert.Equal(t, []any{20, 30}, args) + + sql, args, err = builder.Wrap( + Table("user"), + Select("id", "name", "age"), + Where("id = ?", 1), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT id, name, age FROM user WHERE id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + Distinct("name"), + Where("id = ?", 1), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT DISTINCT name FROM user WHERE id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + Join("address", "user.id = address.user_id"), + Where("user.id = ?", 1), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM user INNER JOIN address ON user.id = address.user_id WHERE user.id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + LeftJoin("address", "user.id = address.user_id"), + Where("user.id = ?", 1), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM user LEFT JOIN address ON user.id = address.user_id WHERE user.id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + RightJoin("address", "user.id = address.user_id"), + Where("user.id = ?", 1), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM user RIGHT JOIN address ON user.id = address.user_id WHERE user.id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + FullJoin("address", "user.id = address.user_id"), + Where("user.id = ?", 1), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM user FULL JOIN address ON user.id = address.user_id WHERE user.id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, _, err = builder.Wrap( + Table("sizes"), + CrossJoin("colors"), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM sizes CROSS JOIN colors", sql) + + sql, args, err = builder.Wrap( + Table("user"), + LeftJoin("address", "user.id = address.user_id"), + RightJoin("company", "user.id = company.user_id"), + Where("user.id = ?", 1), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM user LEFT JOIN address ON user.id = address.user_id RIGHT JOIN company ON user.id = company.user_id WHERE user.id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, args, err = builder.Wrap( + Table("address"), + Select("user_id", "COUNT(*) AS total"), + GroupBy("user_id"), + Having("user_id = ?", 1), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT user_id, COUNT(*) AS total FROM address GROUP BY user_id HAVING user_id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + Where("age > ?", 20), + OrderBy("age ASC", "id DESC"), + Offset(5), + Limit(10), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "SELECT * FROM user WHERE age > ? ORDER BY age ASC, id DESC LIMIT ? OFFSET ?", sql) + assert.Equal(t, []any{20, 10, 5}, args) + + sql, args, err = builder.Wrap( + Table("user_0"), + Where("id = ?", 1), + Union(builder.Wrap(Table("user_1"), Where("id = ?", 2))), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "(SELECT * FROM user_0 WHERE id = ?) UNION (SELECT * FROM user_1 WHERE id = ?)", sql) + assert.Equal(t, []any{1, 2}, args) + + sql, args, err = builder.Wrap( + Table("user_0"), + Where("id = ?", 1), + UnionAll(builder.Wrap(Table("user_1"), Where("id = ?", 2))), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "(SELECT * FROM user_0 WHERE id = ?) UNION ALL (SELECT * FROM user_1 WHERE id = ?)", sql) + assert.Equal(t, []any{1, 2}, args) + + sql, args, err = builder.Wrap( + Table("user_0"), + WhereIn("age IN (?)", []int{10, 20}), + Limit(5), + Union( + builder.Wrap( + Table("user_1"), + WhereIn("age IN (?)", []int{30, 40}), + Limit(5), + ), + ), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "(SELECT * FROM user_0 WHERE age IN (?, ?) LIMIT ?) UNION (SELECT * FROM user_1 WHERE age IN (?, ?) LIMIT ?)", sql) + assert.Equal(t, []any{10, 20, 5, 30, 40, 5}, args) + + sql, args, err = builder.Wrap( + Table("user_0"), + Where("id = ?", 1), + Union(builder.Wrap(Table("user_1"), Where("id = ?", 2))), + UnionAll(builder.Wrap(Table("user_2"), Where("id = ?", 3))), + ).ToQuery(ctx) + + assert.Nil(t, err) + assert.Equal(t, "(SELECT * FROM user_0 WHERE id = ?) UNION (SELECT * FROM user_1 WHERE id = ?) UNION ALL (SELECT * FROM user_2 WHERE id = ?)", sql) + assert.Equal(t, []any{1, 2, 3}, args) +} + +func TestToInsert(t *testing.T) { + ctx := context.TODO() + + builder := NewMySQLBuilder() + + type User struct { + ID int `db:"-"` + Name string `db:"name"` + Gender string `db:"gender"` + Age int `db:"age"` + Phone string `db:"phone,omitempty"` + } + + sql, args, err := builder.Wrap(Table("user")).ToInsert(ctx, &User{ + Name: "yiigo", + Gender: "M", + Age: 29, + }) + + assert.Nil(t, err) + assert.Equal(t, "INSERT INTO user (name, gender, age) VALUES (?, ?, ?)", sql) + assert.Equal(t, []any{"yiigo", "M", 29}, args) + + sql, args, err = builder.Wrap(Table("user")).ToInsert(ctx, &User{ + Name: "yiigo", + Gender: "M", + Age: 29, + Phone: "13605109425", + }) + + assert.Nil(t, err) + assert.Equal(t, "INSERT INTO user (name, gender, age, phone) VALUES (?, ?, ?, ?)", sql) + assert.Equal(t, []any{"yiigo", "M", 29, "13605109425"}, args) + + // map 字段顺序不一定 + // sql, args, err = builder.Wrap(Table("user")).ToInsert(ctx, X{ + // "age": 29, + // "gender": "M", + // "name": "yiigo", + // }) + // + // assert.Equal(t, "INSERT INTO user (age, gender, name) VALUES (?, ?, ?)", sql) + // assert.Equal(t, []any{29, "M", "yiigo"}, args) +} + +func TestToBatchInsert(t *testing.T) { + ctx := context.TODO() + + builder := NewMySQLBuilder() + + type User struct { + ID int `db:"-"` + Name string `db:"name"` + Gender string `db:"gender"` + Age int `db:"age"` + Phone string `db:"phone,omitempty"` + } + + sql, args, err := builder.Wrap(Table("user")).ToBatchInsert(ctx, []*User{ + { + Name: "yiigo", + Gender: "M", + Age: 29, + }, + { + Name: "test", + Gender: "W", + Age: 20, + }, + }) + + assert.Nil(t, err) + assert.Equal(t, "INSERT INTO user (name, gender, age) VALUES (?, ?, ?), (?, ?, ?)", sql) + assert.Equal(t, []any{"yiigo", "M", 29, "test", "W", 20}, args) + + sql, args, err = builder.Wrap(Table("user")).ToBatchInsert(ctx, []*User{ + { + Name: "yiigo", + Gender: "M", + Age: 29, + Phone: "13605109425", + }, + { + Name: "test", + Gender: "W", + Age: 20, + Phone: "13605105471", + }, + }) + + assert.Nil(t, err) + assert.Equal(t, "INSERT INTO user (name, gender, age, phone) VALUES (?, ?, ?, ?), (?, ?, ?, ?)", sql) + assert.Equal(t, []any{"yiigo", "M", 29, "13605109425", "test", "W", 20, "13605105471"}, args) + + // map 字段顺序不一定 + // sql, args, err = builder.Wrap(Table("user")).ToBatchInsert(ctx, []X{ + // { + // "age": 29, + // "gender": "M", + // "name": "yiigo", + // }, + // { + // "age": 20, + // "gender": "W", + // "name": "test", + // }, + // }) + // + // assert.Equal(t, "INSERT INTO user (age, gender, name) VALUES (?, ?, ?), (?, ?, ?)", sql) + // assert.Equal(t, []any{29, "M", "yiigo", 20, "W", "test"}, args) +} + +func TestToUpdate(t *testing.T) { + ctx := context.TODO() + + builder := NewMySQLBuilder() + + type User struct { + Name string `db:"name"` + Gender string `db:"gender"` + Age int `db:"age"` + Phone string `db:"phone,omitempty"` + } + + sql, args, err := builder.Wrap( + Table("user"), + Where("id = ?", 1), + ).ToUpdate(ctx, &User{ + Name: "yiigo", + Gender: "M", + Age: 29, + }) + + assert.Nil(t, err) + assert.Equal(t, "UPDATE user SET name = ?, gender = ?, age = ? WHERE id = ?", sql) + assert.Equal(t, []any{"yiigo", "M", 29, 1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + Where("id = ?", 1), + ).ToUpdate(ctx, &User{ + Name: "yiigo", + Gender: "M", + Age: 29, + Phone: "13605109425", + }) + + assert.Nil(t, err) + assert.Equal(t, "UPDATE user SET name = ?, gender = ?, age = ?, phone = ? WHERE id = ?", sql) + assert.Equal(t, []any{"yiigo", "M", 29, "13605109425", 1}, args) + + // map 字段顺序不一定 + // sql, args, err = builder.Wrap( + // Table("user"), + // Where("id = ?", 1), + // ).ToUpdate(ctx, X{ + // "age": 29, + // "gender": "M", + // "name": "yiigo", + // }) + // + // assert.Equal(t, "UPDATE user SET age = ?, gender = ?, name = ? WHERE id = ?", sql) + // assert.Equal(t, []any{29, "M", "yiigo", 1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + WhereIn("id IN (?)", []int{1, 2}), + ).ToUpdate(ctx, &User{ + Name: "yiigo", + Gender: "M", + Age: 29, + }) + + assert.Nil(t, err) + assert.Equal(t, "UPDATE user SET name = ?, gender = ?, age = ? WHERE id IN (?, ?)", sql) + assert.Equal(t, []any{"yiigo", "M", 29, 1, 2}, args) + + sql, args, err = builder.Wrap( + Table("product"), + Where("id = ?", 1), + ).ToUpdate(ctx, X{"price": SQLExpr("price * ? + ?", 2, 100)}) + + assert.Nil(t, err) + assert.Equal(t, "UPDATE product SET price = price * ? + ? WHERE id = ?", sql) + assert.Equal(t, []any{2, 100, 1}, args) +} + +func TestToDelete(t *testing.T) { + ctx := context.TODO() + + builder := NewMySQLBuilder() + + sql, args, err := builder.Wrap( + Table("user"), + Where("id = ?", 1), + ).ToDelete(ctx) + + assert.Nil(t, err) + assert.Equal(t, "DELETE FROM user WHERE id = ?", sql) + assert.Equal(t, []any{1}, args) + + sql, args, err = builder.Wrap( + Table("user"), + WhereIn("id IN (?)", []int{1, 2}), + ).ToDelete(ctx) + + assert.Nil(t, err) + assert.Equal(t, "DELETE FROM user WHERE id IN (?, ?)", sql) + assert.Equal(t, []any{1, 2}, args) +} + +func TestToTruncate(t *testing.T) { + builder := NewMySQLBuilder() + + assert.Equal(t, "TRUNCATE user", builder.Wrap(Table("user")).ToTruncate(context.TODO())) +} diff --git a/util/string.go b/string.go similarity index 98% rename from util/string.go rename to string.go index 44a4a75..a8f1979 100644 --- a/util/string.go +++ b/string.go @@ -1,4 +1,4 @@ -package util +package yiigo import "strings" diff --git a/util/string_test.go b/string_test.go similarity index 96% rename from util/string_test.go rename to string_test.go index 3323859..8c423b8 100644 --- a/util/string_test.go +++ b/string_test.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "testing" diff --git a/util/time.go b/time.go similarity index 98% rename from util/time.go rename to time.go index 1685eab..bf37d5f 100644 --- a/util/time.go +++ b/time.go @@ -1,4 +1,4 @@ -package util +package yiigo import "time" diff --git a/util/time_test.go b/time_test.go similarity index 97% rename from util/time_test.go rename to time_test.go index a761ffd..3fab466 100644 --- a/util/time_test.go +++ b/time_test.go @@ -1,4 +1,4 @@ -package util +package yiigo import ( "testing" diff --git a/type.go b/type.go new file mode 100644 index 0000000..0bb7947 --- /dev/null +++ b/type.go @@ -0,0 +1,13 @@ +package yiigo + +// X 类型别名 +type X map[string]any + +// DBDriver 数据库驱动 +type DBDriver string + +const ( + MySQL DBDriver = "mysql" + Postgres DBDriver = "pgx" + SQLite DBDriver = "sqlite3" +) diff --git a/websocket/dial.go b/websocket/dial.go index f5feebd..5accf76 100644 --- a/websocket/dial.go +++ b/websocket/dial.go @@ -3,9 +3,7 @@ package websocket import ( "context" "fmt" - "log" "net/http" - "runtime/debug" "time" "github.com/gorilla/websocket" @@ -72,12 +70,6 @@ func (c *DialConn) reconnect() error { // Read 读消息,若失败会尝试重连 (reconnectTimeout<=0 表示重连不超时) func (c *DialConn) Read(reconnectTimeout time.Duration, handler func(msg *Message)) error { - defer func() { - if err := recover(); err != nil { - log.Printf("websocket read panic, error: %v, stack: %s\n", err, string(debug.Stack())) - } - }() - for { t, b, err := c.conn.ReadMessage() if err == nil { diff --git a/websocket/upgrader.go b/websocket/upgrader.go index f6a77b0..e6b9198 100644 --- a/websocket/upgrader.go +++ b/websocket/upgrader.go @@ -3,9 +3,7 @@ package websocket import ( "context" "errors" - "log" "net/http" - "runtime/debug" "github.com/gorilla/websocket" ) @@ -33,12 +31,6 @@ type UpgradeConn struct { // Read 读消息 func (c *UpgradeConn) Read(ctx context.Context, handler func(ctx context.Context, msg *Message) (*Message, error)) error { - defer func() { - if err := recover(); err != nil { - log.Printf("websocket read panic, error: %v, stack: %s\n", err, string(debug.Stack())) - } - }() - for { select { case <-ctx.Done():