From cbc7e9569032dba0886c5341391d723367fffe78 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 May 2024 10:48:18 -0500 Subject: [PATCH] stdlib matches native pgx scanning support stdlib can now directly scan into anything pgx can scan such as Go slices. This requires the change to database/sql implemented by https://github.com/golang/go/pull/67648. If this PR is accepted it will most likely land in Go 1.24. --- stdlib/bench_test.go | 51 +++++++++++++ stdlib/sql.go | 6 ++ stdlib/sql_test.go | 166 ++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 212 insertions(+), 11 deletions(-) diff --git a/stdlib/bench_test.go b/stdlib/bench_test.go index 33734fc1b..141fc4eb5 100644 --- a/stdlib/bench_test.go +++ b/stdlib/bench_test.go @@ -8,6 +8,8 @@ import ( "strings" "testing" "time" + + "github.com/jackc/pgx/v5/pgtype" ) func getSelectRowsCounts(b *testing.B) []int64 { @@ -107,3 +109,52 @@ func BenchmarkSelectRowsScanNull(b *testing.B) { }) } } + +func BenchmarkFlatArrayEncodeArgument(b *testing.B) { + db := openDB(b) + defer closeDB(b, db) + + input := make(pgtype.FlatArray[string], 10) + for i := range input { + input[i] = fmt.Sprintf("String %d", i) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var n int64 + err := db.QueryRow("select cardinality($1::text[])", input).Scan(&n) + if err != nil { + b.Fatal(err) + } + if n != int64(len(input)) { + b.Fatalf("Expected %d, got %d", len(input), n) + } + } +} + +func BenchmarkFlatArrayScanResult(b *testing.B) { + db := openDB(b) + defer closeDB(b, db) + + var input string + for i := 0; i < 10; i++ { + if i > 0 { + input += "," + } + input += fmt.Sprintf(`'String %d'`, i) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var result pgtype.FlatArray[string] + err := db.QueryRow(fmt.Sprintf("select array[%s]::text[]", input)).Scan(&result) + if err != nil { + b.Fatal(err) + } + if len(result) != 10 { + b.Fatalf("Expected %d, got %d", len(result), 10) + } + } +} diff --git a/stdlib/sql.go b/stdlib/sql.go index 29cd3fbbf..885cae324 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -847,6 +847,12 @@ func (r *Rows) Next(dest []driver.Value) error { return nil } +func (r *Rows) ScanColumn(index int, dest any) error { + m := r.conn.conn.TypeMap() + fd := r.rows.FieldDescriptions()[index] + return m.Scan(fd.DataTypeOID, fd.Format, r.rows.RawValues()[index], dest) +} + func valueToInterface(argsV []driver.Value) []any { args := make([]any, 0, len(argsV)) for _, v := range argsV { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 1e3db0d35..92ff84564 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -107,6 +107,32 @@ func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) { } } +func testWithKnownOIDQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) { + for _, mode := range []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + } { + t.Run(mode.String(), + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.DefaultQueryExecMode = mode + db := stdlib.OpenDB(*config) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + f(t, db) + + ensureDBValid(t, db) + }, + ) + } +} + // Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should // cover broken connections. func ensureDBValid(t testing.TB, db *sql.DB) { @@ -509,29 +535,99 @@ func TestConnQueryScanGoArray(t *testing.T) { }) } -func TestConnQueryScanArray(t *testing.T) { +func TestGoArray(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - m := pgtype.NewMap() + var names []string - var a pgtype.Array[int64] - err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) + err := db.QueryRow("select array['John', 'Jane']::text[]").Scan(&names) require.NoError(t, err) - assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a) + require.Equal(t, []string{"John", "Jane"}, names) - err = db.QueryRow("select null::bigint[]").Scan(m.SQLScanner(&a)) + var n int + err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n) require.NoError(t, err) - assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a) + require.EqualValues(t, 2, n) + + err = db.QueryRow("select null::text[]").Scan(&names) + require.NoError(t, err) + require.Nil(t, names) }) } -func TestConnQueryScanRange(t *testing.T) { +func TestGoArrayOfDriverValuer(t *testing.T) { + // Because []sql.NullString is not a registered type on the connection, it will only work with known OIDs. + testWithKnownOIDQueryExecModes(t, func(t *testing.T, db *sql.DB) { + var names []sql.NullString + + err := db.QueryRow("select array['John', null, 'Jane']::text[]").Scan(&names) + require.NoError(t, err) + require.Equal(t, []sql.NullString{{String: "John", Valid: true}, {}, {String: "Jane", Valid: true}}, names) + + var n int + err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) + + err = db.QueryRow("select null::text[]").Scan(&names) + require.NoError(t, err) + require.Nil(t, names) + }) +} + +func TestPGTypeFlatArray(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - skipCockroachDB(t, db, "Server does not support int4range") + var names pgtype.FlatArray[string] - m := pgtype.NewMap() + err := db.QueryRow("select array['John', 'Jane']::text[]").Scan(&names) + require.NoError(t, err) + require.Equal(t, pgtype.FlatArray[string]{"John", "Jane"}, names) + + var n int + err = db.QueryRow("select cardinality($1::text[])", names).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) + + err = db.QueryRow("select null::text[]").Scan(&names) + require.NoError(t, err) + require.Nil(t, names) + }) +} + +func TestPGTypeArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support nested arrays") + + var matrix pgtype.Array[int64] + + err := db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[]").Scan(&matrix) + require.NoError(t, err) + require.Equal(t, + pgtype.Array[int64]{ + Elements: []int64{1, 2, 3, 4, 5, 6}, + Dims: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + }, + Valid: true}, + matrix) + + var equal bool + err = db.QueryRow("select '{{1,2,3},{4,5,6}}'::bigint[] = $1::bigint[]", matrix).Scan(&equal) + require.NoError(t, err) + require.Equal(t, true, equal) + + err = db.QueryRow("select null::bigint[]").Scan(&matrix) + require.NoError(t, err) + assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, matrix) + }) +} + +func TestConnQueryPGTypeRange(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support int4range") var r pgtype.Range[pgtype.Int4] - err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r)) + err := db.QueryRow("select int4range(1, 5)").Scan(&r) require.NoError(t, err) assert.Equal( t, @@ -543,6 +639,54 @@ func TestConnQueryScanRange(t *testing.T) { Valid: true, }, r) + + var equal bool + err = db.QueryRow("select int4range(1, 5) = $1::int4range", r).Scan(&equal) + require.NoError(t, err) + require.Equal(t, true, equal) + + err = db.QueryRow("select null::int4range").Scan(&r) + require.NoError(t, err) + assert.Equal(t, pgtype.Range[pgtype.Int4]{}, r) + }) +} + +func TestConnQueryPGTypeMultirange(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + skipCockroachDB(t, db, "Server does not support int4range") + skipPostgreSQLVersionLessThan(t, db, 14) + + var r pgtype.Multirange[pgtype.Range[pgtype.Int4]] + err := db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9))").Scan(&r) + require.NoError(t, err) + assert.Equal( + t, + pgtype.Multirange[pgtype.Range[pgtype.Int4]]{ + { + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + { + Lower: pgtype.Int4{Int32: 7, Valid: true}, + Upper: pgtype.Int4{Int32: 9, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + r) + + var equal bool + err = db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9)) = $1::int4multirange", r).Scan(&equal) + require.NoError(t, err) + require.Equal(t, true, equal) + + err = db.QueryRow("select null::int4multirange").Scan(&r) + require.NoError(t, err) + require.Nil(t, r) }) }