Skip to content

Commit

Permalink
added mssql support (#223)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexey Tyuryumov <Alexey.Tyuryumov@acronis.com>
  • Loading branch information
Alexey19 and Alexey Tyuryumov authored Nov 16, 2020
1 parent 5ee9f28 commit 44bd707
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 31 deletions.
8 changes: 6 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ jobs:
environment:
- DBR_TEST_MYSQL_DSN=root:root@tcp(127.0.0.1:3306)/circle_test
- DBR_TEST_POSTGRES_DSN=postgres://postgres:mysecretpassword@127.0.0.1:5432/postgres?sslmode=disable
- DBR_TEST_MSSQL_DSN=sqlserver://sa:qwe123QWE@127.0.0.1/?database=master
- GO111MODULE=on
- image: percona:5.7
environment:
Expand All @@ -14,7 +15,10 @@ jobs:
- image: postgres:11-alpine
environment:
- POSTGRES_PASSWORD=mysecretpassword

- image: mcr.microsoft.com/mssql/server:2019-latest
environment:
- ACCEPT_EULA=Y
- SA_PASSWORD=qwe123QWE
working_directory: /go/src/github.com/gocraft/dbr
steps:
- checkout
Expand All @@ -25,5 +29,5 @@ jobs:
DOCKERIZE_VERSION: v0.3.0
- run:
name: Wait for db
command: dockerize -wait tcp://127.0.0.1:3306 -wait tcp://127.0.0.1:5432 -timeout 1m
command: dockerize -wait tcp://127.0.0.1:3306 -wait tcp://127.0.0.1:5432 -wait tcp://127.0.0.1:1433 -timeout 1m
- run: go test -v -cover -bench . ./...
2 changes: 2 additions & 0 deletions dbr.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ func Open(driver, dsn string, log EventReceiver) (*Connection, error) {
d = dialect.PostgreSQL
case "sqlite3":
d = dialect.SQLite3
case "mssql":
d = dialect.MSSQL
default:
return nil, ErrNotSupported
}
Expand Down
28 changes: 19 additions & 9 deletions dbr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/gocraft/dbr/v2/dialect"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/denisenkom/go-mssqldb"
"github.com/stretchr/testify/require"
)

Expand All @@ -22,6 +23,7 @@ var (
mysqlDSN = os.Getenv("DBR_TEST_MYSQL_DSN")
postgresDSN = os.Getenv("DBR_TEST_POSTGRES_DSN")
sqlite3DSN = ":memory:"
mssqlDSN = os.Getenv("DBR_TEST_MSSQL_DSN")
)

func createSession(driver, dsn string) *Session {
Expand All @@ -37,9 +39,10 @@ var (
postgresSession = createSession("postgres", postgresDSN)
postgresBinarySession = createSession("postgres", postgresDSN+"&binary_parameters=yes")
sqlite3Session = createSession("sqlite3", sqlite3DSN)
mssqlSession = createSession("mssql", mssqlDSN)

// all test sessions should be here
testSession = []*Session{mysqlSession, postgresSession, sqlite3Session}
testSession = []*Session{mysqlSession, postgresSession, sqlite3Session, mssqlSession}
)

type dbrPerson struct {
Expand All @@ -58,14 +61,17 @@ type nullTypedRecord struct {
}

func reset(t *testing.T, sess *Session) {
var autoIncrementType string
autoIncrementType := "serial PRIMARY KEY"
boolType := "bool"
datetimeType := "timestamp"

switch sess.Dialect {
case dialect.MySQL:
autoIncrementType = "serial PRIMARY KEY"
case dialect.PostgreSQL:
autoIncrementType = "serial PRIMARY KEY"
case dialect.SQLite3:
autoIncrementType = "integer PRIMARY KEY"
case dialect.MSSQL:
autoIncrementType = "integer IDENTITY PRIMARY KEY"
boolType = "BIT"
datetimeType = "datetime"
}
for _, v := range []string{
`DROP TABLE IF EXISTS dbr_people`,
Expand All @@ -81,9 +87,9 @@ func reset(t *testing.T, sess *Session) {
string_val varchar(255) NULL,
int64_val integer NULL,
float64_val float NULL,
time_val timestamp NULL ,
bool_val bool NULL
)`, autoIncrementType),
time_val %s NULL,
bool_val %s NULL
)`, autoIncrementType, datetimeType, boolType),
} {
_, err := sess.Exec(v)
require.NoError(t, err)
Expand All @@ -105,6 +111,10 @@ func TestBasicCRUD(t *testing.T) {
jonathan.Id = 1
insertColumns = []string{"id", "name", "email"}
}
if sess.Dialect == dialect.MSSQL {
jonathan.Id = 1
}

// insert
result, err := sess.InsertInto("dbr_people").Columns(insertColumns...).Record(&jonathan).Exec()
require.NoError(t, err)
Expand Down
2 changes: 2 additions & 0 deletions dialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ var (
PostgreSQL = postgreSQL{}
// SQLite3 dialect
SQLite3 = sqlite3{}
// MSSQL dialect
MSSQL = mssql{}
)

const (
Expand Down
18 changes: 18 additions & 0 deletions dialect/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,21 @@ func TestSQLite3(t *testing.T) {
require.Equal(t, test.want, SQLite3.QuoteIdent(test.in))
}
}

func TestMSSQL(t *testing.T) {
for _, test := range []struct {
in string
want string
}{
{
in: "table.col",
want: `"table"."col"`,
},
{
in: "col",
want: `"col"`,
},
} {
require.Equal(t, test.want, MSSQL.QuoteIdent(test.in))
}
}
36 changes: 36 additions & 0 deletions dialect/mssql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package dialect

import (
"fmt"
"strings"
"time"
)

type mssql struct{}

func (d mssql) QuoteIdent(s string) string {
return quoteIdent(s, `"`)
}

func (d mssql) EncodeString(s string) string {
return `'` + strings.Replace(s, `'`, `''`, -1) + `'`
}

func (d mssql) EncodeBool(b bool) string {
if b {
return "1"
}
return "0"
}

func (d mssql) EncodeTime(t time.Time) string {
return t.Format("'2006-01-02 15:04:05.999'")
}

func (d mssql) EncodeBytes(b []byte) string {
return fmt.Sprintf(`E'\\x%x'`, b)
}

func (d mssql) Placeholder(n int) string {
return fmt.Sprintf("@p%d", n+1)
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.13

require (
github.com/DATA-DOG/go-sqlmock v1.4.1
github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204
github.com/go-sql-driver/mysql v1.5.0
github.com/jmoiron/sqlx v1.2.0
github.com/lib/pq v1.3.0
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ github.com/DATA-DOG/go-sqlmock v1.4.1 h1:ThlnYciV1iM/V0OSF/dtkqWb6xo5qITT1TJBG1M
github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204 h1:tI48fqaIkxxYuIylVv1tdDfBp6836GKSfmmzgSyP1CY=
github.com/denisenkom/go-mssqldb v0.0.0-20200910202707-1e08a3fab204/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
Expand All @@ -22,6 +26,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
Expand Down
20 changes: 17 additions & 3 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"database/sql"
"reflect"
"strings"

"github.com/gocraft/dbr/v2/dialect"
)

// InsertStmt builds `INSERT INTO ...`.
Expand Down Expand Up @@ -63,7 +65,19 @@ func (b *InsertStmt) Build(d Dialect, buf Buffer) error {
buf.WriteString(d.QuoteIdent(col))
placeholderBuf.WriteString(placeholder)
}
buf.WriteString(") VALUES ")
buf.WriteString(")")

if d == dialect.MSSQL && len(b.ReturnColumn) > 0 {
buf.WriteString(" OUTPUT ")
for i, col := range b.ReturnColumn {
if i > 0 {
buf.WriteString(",")
}
buf.WriteString("INSERTED." + d.QuoteIdent(col))
}
}

buf.WriteString(" VALUES ")
placeholderBuf.WriteString(")")
placeholderStr := placeholderBuf.String()

Expand All @@ -76,7 +90,7 @@ func (b *InsertStmt) Build(d Dialect, buf Buffer) error {
buf.WriteValue(tuple...)
}

if len(b.ReturnColumn) > 0 {
if d != dialect.MSSQL && len(b.ReturnColumn) > 0 {
buf.WriteString(" RETURNING ")
for i, col := range b.ReturnColumn {
if i > 0 {
Expand Down Expand Up @@ -199,7 +213,7 @@ func (b *InsertStmt) Record(structValue interface{}) *InsertStmt {
return b
}

// Returning specifies the returning columns for postgres.
// Returning specifies the returning columns for postgres/mssql.
func (b *InsertStmt) Returning(column ...string) *InsertStmt {
b.ReturnColumn = column
return b
Expand Down
56 changes: 49 additions & 7 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"strconv"

"github.com/gocraft/dbr/v2/dialect"
)

// SelectStmt builds `SELECT ...`.
Expand Down Expand Up @@ -130,14 +132,18 @@ func (b *SelectStmt) Build(d Dialect, buf Buffer) error {
}
}

if b.LimitCount >= 0 {
buf.WriteString(" LIMIT ")
buf.WriteString(strconv.FormatInt(b.LimitCount, 10))
}
if d == dialect.MSSQL {
b.addMSSQLLimits(buf)
} else {
if b.LimitCount >= 0 {
buf.WriteString(" LIMIT ")
buf.WriteString(strconv.FormatInt(b.LimitCount, 10))
}

if b.OffsetCount >= 0 {
buf.WriteString(" OFFSET ")
buf.WriteString(strconv.FormatInt(b.OffsetCount, 10))
if b.OffsetCount >= 0 {
buf.WriteString(" OFFSET ")
buf.WriteString(strconv.FormatInt(b.OffsetCount, 10))
}
}

if len(b.Suffixes) > 0 {
Expand All @@ -153,6 +159,42 @@ func (b *SelectStmt) Build(d Dialect, buf Buffer) error {
return nil
}

// https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2012/ms188385(v=sql.110)
func (b *SelectStmt) addMSSQLLimits(buf Buffer) {
limitCount := b.LimitCount
offsetCount := b.OffsetCount
if limitCount < 0 && offsetCount < 0 {
return
}
if offsetCount < 0 {
offsetCount = 0
}

if len(b.Order) == 0 {
// ORDER is required for OFFSET / FETCH
buf.WriteString(" ORDER BY ")
col := b.Column[0]
switch col := col.(type) {
case string:
// FIXME: no quote ident
buf.WriteString(col)
default:
buf.WriteString(placeholder)
buf.WriteValue(col)
}
}

buf.WriteString(" OFFSET ")
buf.WriteString(strconv.FormatInt(offsetCount, 10))
buf.WriteString(" ROWS ")

if limitCount >= 0 {
buf.WriteString(" FETCH FIRST ")
buf.WriteString(strconv.FormatInt(limitCount, 10))
buf.WriteString(" ROWS ONLY ")
}
}

// Select creates a SelectStmt.
func Select(column ...interface{}) *SelectStmt {
return &SelectStmt{
Expand Down
27 changes: 18 additions & 9 deletions transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dbr
import (
"testing"

"github.com/gocraft/dbr/v2/dialect"
"github.com/stretchr/testify/require"
)

Expand All @@ -14,17 +15,22 @@ func TestTransactionCommit(t *testing.T) {
require.NoError(t, err)
defer tx.RollbackUnlessCommitted()

id := 1
elem_count := 1
if sess.Dialect == dialect.MSSQL {
tx.UpdateBySql("SET IDENTITY_INSERT dbr_people ON;").Exec()
elem_count += 1
}

id := 1
result, err := tx.InsertInto("dbr_people").Columns("id", "name", "email").Values(id, "Barack", "obama@whitehouse.gov").Comment("INSERT TEST").Exec()
require.NoError(t, err)
require.Len(t, sess.EventReceiver.(*testTraceReceiver).started, 1)
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].eventName, "dbr.exec")
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "/* INSERT TEST */\n")
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "INSERT")
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "dbr_people")
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[0].query, "name")
require.Equal(t, 1, sess.EventReceiver.(*testTraceReceiver).finished)
require.Len(t, sess.EventReceiver.(*testTraceReceiver).started, elem_count)
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[elem_count-1].eventName, "dbr.exec")
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[elem_count-1].query, "/* INSERT TEST */\n")
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[elem_count-1].query, "INSERT")
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[elem_count-1].query, "dbr_people")
require.Contains(t, sess.EventReceiver.(*testTraceReceiver).started[elem_count-1].query, "name")
require.Equal(t, elem_count, sess.EventReceiver.(*testTraceReceiver).finished)
require.Equal(t, 0, sess.EventReceiver.(*testTraceReceiver).errored)

rowsAffected, err := result.RowsAffected()
Expand All @@ -49,8 +55,11 @@ func TestTransactionRollback(t *testing.T) {
require.NoError(t, err)
defer tx.RollbackUnlessCommitted()

id := 1
if sess.Dialect == dialect.MSSQL {
tx.UpdateBySql("SET IDENTITY_INSERT dbr_people ON;").Exec()
}

id := 1
result, err := tx.InsertInto("dbr_people").Columns("id", "name", "email").Values(id, "Barack", "obama@whitehouse.gov").Exec()
require.NoError(t, err)

Expand Down
Loading

0 comments on commit 44bd707

Please sign in to comment.