Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add DecodeOption for returning protobuf values #341

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,30 @@ func init() {
sql.Register("spanner", &Driver{connectors: make(map[string]*connector)})
}

// ExecOptions can be passed in as an argument to the Query, QueryContext,
// Exec, and ExecContext functions to specify additional execution options
// for a statement.
type ExecOptions struct {
// DecodeOption indicates how the returned rows should be decoded.
DecodeOption DecodeOption
}

type DecodeOption int

const (
// DecodeOptionNormal decodes into idiomatic Go types (e.g. bool, string, int64, etc.)
DecodeOptionNormal DecodeOption = iota

// DecodeOptionProto does not decode the returned rows at all, and instead just returns
// the underlying protobuf objects. Use this for advanced use-cases where you want
// direct access to the underlying values.
// All values should be scanned into an instance of spanner.GenericColumnValue like this:
//
// var v spanner.GenericColumnValue
// row.Scan(&v)
DecodeOptionProto
)

// Driver represents a Google Cloud Spanner database/sql driver.
type Driver struct {
mu sync.Mutex
Expand Down Expand Up @@ -652,9 +676,13 @@ type conn struct {
autocommitDMLMode AutocommitDMLMode
// readOnlyStaleness is used for queries in autocommit mode and for read-only transactions.
readOnlyStaleness spanner.TimestampBound
// excludeTxnFromChangeStreams is used to exlude the next transaction from change streams with the DDL option
// excludeTxnFromChangeStreams is used to exclude the next transaction from change streams with the DDL option
// `allow_txn_exclusion=true`
excludeTxnFromChangeStreams bool
// execOptions are applied to the next statement that is executed on this connection.
// It can only be set by passing it in as an argument to ExecContext or QueryContext
// and is cleared after each execution.
execOptions ExecOptions
}

type batchType int
Expand Down Expand Up @@ -1077,6 +1105,12 @@ func (c *conn) CheckNamedValue(value *driver.NamedValue) error {
if value == nil {
return nil
}

if execOptions, ok := value.Value.(ExecOptions); ok {
c.execOptions = execOptions
return driver.ErrRemoveArgument
}

if checkIsValidType(value.Value) {
return nil
}
Expand Down Expand Up @@ -1113,6 +1147,8 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
}

func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
execOptions := c.options()

// Execute client side statement if it is one.
clientStmt, err := parseClientSideStatement(c, query)
if err != nil {
Expand All @@ -1134,10 +1170,13 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
} else {
iter = c.tx.Query(ctx, stmt)
}
return &rows{it: iter}, nil
return &rows{it: iter, decodeOption: execOptions.DecodeOption}, nil
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
// Make sure options are reset after calling this method.
_ = c.options()

// Execute client side statement if it is one.
stmt, err := parseClientSideStatement(c, query)
if err != nil {
Expand Down Expand Up @@ -1195,6 +1234,12 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return &result{rowsAffected: rowsAffected}, nil
}

// options returns and resets the ExecOptions for the next statement.
func (c *conn) options() ExecOptions {
defer func() { c.execOptions = ExecOptions{} }()
return c.execOptions
}

func (c *conn) Close() error {
return c.connector.decreaseConnCount()
}
Expand Down
97 changes: 97 additions & 0 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,103 @@ func TestQueryWithNullParameters(t *testing.T) {
}
}

func TestQueryWithAllTypes_ReturnProto(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()
query := "SELECT * FROM Test"
_ = server.TestSpanner.PutStatementResult(
query,
&testutil.StatementResult{
Type: testutil.StatementResultResultSet,
ResultSet: testutil.CreateResultSetWithAllTypes(false),
},
)

for _, prepare := range []bool{false, true} {
var rows *sql.Rows
if prepare {
stmt, err := db.Prepare(query)
if err != nil {
t.Fatal(err)
}
rows, err = stmt.QueryContext(context.Background(), ExecOptions{DecodeOption: DecodeOptionProto})
if err != nil {
t.Fatal(err)
}
stmt.Close()
} else {
var err error
rows, err = db.QueryContext(context.Background(), query, ExecOptions{DecodeOption: DecodeOptionProto})
if err != nil {
t.Fatal(err)
}
}

for rows.Next() {
var b spanner.GenericColumnValue
var s spanner.GenericColumnValue
var bt spanner.GenericColumnValue
var i spanner.GenericColumnValue
var f32 spanner.GenericColumnValue
var f spanner.GenericColumnValue
var r spanner.GenericColumnValue
var d spanner.GenericColumnValue
var ts spanner.GenericColumnValue
var j spanner.GenericColumnValue
var bArray spanner.GenericColumnValue
var sArray spanner.GenericColumnValue
var btArray spanner.GenericColumnValue
var iArray spanner.GenericColumnValue
var f32Array spanner.GenericColumnValue
var fArray spanner.GenericColumnValue
var rArray spanner.GenericColumnValue
var dArray spanner.GenericColumnValue
var tsArray spanner.GenericColumnValue
var jArray spanner.GenericColumnValue
err := rows.Scan(&b, &s, &bt, &i, &f32, &f, &r, &d, &ts, &j, &bArray, &sArray, &btArray, &iArray, &f32Array, &fArray, &rArray, &dArray, &tsArray, &jArray)
if err != nil {
t.Fatal(err)
}
if g, w := b.Value.GetBoolValue(), true; g != w {
t.Errorf("row value mismatch for bool\nGot: %v\nWant: %v", g, w)
}
if g, w := s.Value.GetStringValue(), "test"; g != w {
t.Errorf("row value mismatch for string\nGot: %v\nWant: %v", g, w)
}
if g, w := bt.Value.GetStringValue(), base64.RawURLEncoding.EncodeToString([]byte("testbytes")); !cmp.Equal(g, w) {
t.Errorf("row value mismatch for bytes\nGot: %v\nWant: %v", g, w)
}
if g, w := i.Value.GetStringValue(), "5"; g != w {
t.Errorf("row value mismatch for int64\nGot: %v\nWant: %v", g, w)
}
if g, w := float32(f32.Value.GetNumberValue()), float32(3.14); g != w {
t.Errorf("row value mismatch for float32\nGot: %v\nWant: %v", g, w)
}
if g, w := f.Value.GetNumberValue(), 3.14; g != w {
t.Errorf("row value mismatch for float64\nGot: %v\nWant: %v", g, w)
}
if g, w := r.Value.GetStringValue(), "6.626"; g != w {
t.Errorf("row value mismatch for numeric\nGot: %v\nWant: %v", g, w)
}
if g, w := d.Value.GetStringValue(), "2021-07-21"; g != w {
t.Errorf("row value mismatch for date\nGot: %v\nWant: %v", g, w)
}
if g, w := ts.Value.GetStringValue(), "2021-07-21T21:07:59.339911800Z"; g != w {
t.Errorf("row value mismatch for timestamp\nGot: %v\nWant: %v", g, w)
}
if g, w := j.Value.GetStringValue(), `{"key": "value", "other-key": ["value1", "value2"]}`; g != w {
t.Errorf("row value mismatch for json\n Got: %v\nWant: %v", g, w)
}
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
rows.Close()
}
}

func TestDmlInAutocommit(t *testing.T) {
t.Parallel()

Expand Down
69 changes: 69 additions & 0 deletions examples/decode-options/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"context"
"database/sql"
"fmt"

"cloud.google.com/go/spanner"
spannerdriver "github.com/googleapis/go-sql-spanner"
"github.com/googleapis/go-sql-spanner/examples"
)

// Example for getting the underlying protobuf objects from a query result
// instead of decoding the values into Go types. This can be used for
// advanced use cases where you want to have full control over how data
// is decoded, or where you want to skip the decode step for performance
// reasons.
func decodeOptions(projectId, instanceId, databaseId string) error {
ctx := context.Background()
db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, databaseId))
if err != nil {
return fmt.Errorf("failed to open database connection: %v", err)
}
defer db.Close()

// Pass an ExecOptions value with DecodeOption set to DecodeOptionProto
// as an argument to QueryContext to instruct the Spanner driver to skip
// decoding the data into Go types.
rows, err := db.QueryContext(ctx,
`SELECT JSON '{"key1": "value1", "key2": 2, "key3": ["value1", "value2"]}'`,
spannerdriver.ExecOptions{DecodeOption: spannerdriver.DecodeOptionProto})
if err != nil {
return fmt.Errorf("failed to execute query: %v", err)
}
defer rows.Close()

for rows.Next() {
// As we are using DecodeOptionProto, all values must be scanned
// into spanner.GenericColumnValue.
var value spanner.GenericColumnValue

if err := rows.Scan(&value); err != nil {
return fmt.Errorf("failed to scan row values: %v", err)
}
fmt.Printf("Received value %v\n", value.Value.GetStringValue())
}
if err := rows.Err(); err != nil {
return fmt.Errorf("failed to execute query: %v", err)
}
return nil
}

func main() {
examples.RunSampleOnEmulator(decodeOptions)
}
11 changes: 8 additions & 3 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import (
type rows struct {
it rowIterator

colsOnce sync.Once
dirtyErr error
cols []string
colsOnce sync.Once
dirtyErr error
cols []string
decodeOption DecodeOption

dirtyRow *spanner.Row
}
Expand Down Expand Up @@ -107,6 +108,10 @@ func (r *rows) Next(dest []driver.Value) error {
if err := row.Column(i, &col); err != nil {
return err
}
if r.decodeOption == DecodeOptionProto {
dest[i] = col
continue
}
switch col.Type.Code {
case sppb.TypeCode_INT64:
var v spanner.NullInt64
Expand Down
21 changes: 17 additions & 4 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ import (
)

type stmt struct {
conn *conn
numArgs int
query string
conn *conn
numArgs int
query string
execOptions ExecOptions
}

func (s *stmt) Close() error {
Expand Down Expand Up @@ -61,7 +62,19 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
} else {
it = &readOnlyRowIterator{s.conn.client.Single().WithTimestampBound(s.conn.readOnlyStaleness).Query(ctx, ss)}
}
return &rows{it: it}, nil
return &rows{it: it, decodeOption: s.execOptions.DecodeOption}, nil
}

func (s *stmt) CheckNamedValue(value *driver.NamedValue) error {
if value == nil {
return nil
}

if execOptions, ok := value.Value.(ExecOptions); ok {
s.execOptions = execOptions
return driver.ErrRemoveArgument
}
return nil
}

func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement, error) {
Expand Down
Loading