diff --git a/driver.go b/driver.go index c851b4af..c8536992 100644 --- a/driver.go +++ b/driver.go @@ -78,6 +78,30 @@ func init() { sql.Register("spanner", &Driver{connectors: make(map[string]*connector)}) } +// ExecOptions can be passed in as an argument to the Query, QueryContext, +// Exec, and ExecContext functions to specify additional execution options +// for a statement. +type ExecOptions struct { + // DecodeOption indicates how the returned rows should be decoded. + DecodeOption DecodeOption +} + +type DecodeOption int + +const ( + // DecodeOptionNormal decodes into idiomatic Go types (e.g. bool, string, int64, etc.) + DecodeOptionNormal DecodeOption = iota + + // DecodeOptionProto does not decode the returned rows at all, and instead just returns + // the underlying protobuf objects. Use this for advanced use-cases where you want + // direct access to the underlying values. + // All values should be scanned into an instance of spanner.GenericColumnValue like this: + // + // var v spanner.GenericColumnValue + // row.Scan(&v) + DecodeOptionProto +) + // Driver represents a Google Cloud Spanner database/sql driver. type Driver struct { mu sync.Mutex @@ -652,9 +676,13 @@ type conn struct { autocommitDMLMode AutocommitDMLMode // readOnlyStaleness is used for queries in autocommit mode and for read-only transactions. readOnlyStaleness spanner.TimestampBound - // excludeTxnFromChangeStreams is used to exlude the next transaction from change streams with the DDL option + // excludeTxnFromChangeStreams is used to exclude the next transaction from change streams with the DDL option // `allow_txn_exclusion=true` excludeTxnFromChangeStreams bool + // execOptions are applied to the next statement that is executed on this connection. + // It can only be set by passing it in as an argument to ExecContext or QueryContext + // and is cleared after each execution. + execOptions ExecOptions } type batchType int @@ -1077,6 +1105,12 @@ func (c *conn) CheckNamedValue(value *driver.NamedValue) error { if value == nil { return nil } + + if execOptions, ok := value.Value.(ExecOptions); ok { + c.execOptions = execOptions + return driver.ErrRemoveArgument + } + if checkIsValidType(value.Value) { return nil } @@ -1113,6 +1147,8 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + execOptions := c.options() + // Execute client side statement if it is one. clientStmt, err := parseClientSideStatement(c, query) if err != nil { @@ -1134,10 +1170,13 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } else { iter = c.tx.Query(ctx, stmt) } - return &rows{it: iter}, nil + return &rows{it: iter, decodeOption: execOptions.DecodeOption}, nil } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + // Make sure options are reset after calling this method. + _ = c.options() + // Execute client side statement if it is one. stmt, err := parseClientSideStatement(c, query) if err != nil { @@ -1195,6 +1234,12 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name return &result{rowsAffected: rowsAffected}, nil } +// options returns and resets the ExecOptions for the next statement. +func (c *conn) options() ExecOptions { + defer func() { c.execOptions = ExecOptions{} }() + return c.execOptions +} + func (c *conn) Close() error { return c.connector.decreaseConnCount() } diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 1817c426..4e565186 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -1062,6 +1062,103 @@ func TestQueryWithNullParameters(t *testing.T) { } } +func TestQueryWithAllTypes_ReturnProto(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + query := "SELECT * FROM Test" + _ = server.TestSpanner.PutStatementResult( + query, + &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: testutil.CreateResultSetWithAllTypes(false), + }, + ) + + for _, prepare := range []bool{false, true} { + var rows *sql.Rows + if prepare { + stmt, err := db.Prepare(query) + if err != nil { + t.Fatal(err) + } + rows, err = stmt.QueryContext(context.Background(), ExecOptions{DecodeOption: DecodeOptionProto}) + if err != nil { + t.Fatal(err) + } + stmt.Close() + } else { + var err error + rows, err = db.QueryContext(context.Background(), query, ExecOptions{DecodeOption: DecodeOptionProto}) + if err != nil { + t.Fatal(err) + } + } + + for rows.Next() { + var b spanner.GenericColumnValue + var s spanner.GenericColumnValue + var bt spanner.GenericColumnValue + var i spanner.GenericColumnValue + var f32 spanner.GenericColumnValue + var f spanner.GenericColumnValue + var r spanner.GenericColumnValue + var d spanner.GenericColumnValue + var ts spanner.GenericColumnValue + var j spanner.GenericColumnValue + var bArray spanner.GenericColumnValue + var sArray spanner.GenericColumnValue + var btArray spanner.GenericColumnValue + var iArray spanner.GenericColumnValue + var f32Array spanner.GenericColumnValue + var fArray spanner.GenericColumnValue + var rArray spanner.GenericColumnValue + var dArray spanner.GenericColumnValue + var tsArray spanner.GenericColumnValue + var jArray spanner.GenericColumnValue + err := rows.Scan(&b, &s, &bt, &i, &f32, &f, &r, &d, &ts, &j, &bArray, &sArray, &btArray, &iArray, &f32Array, &fArray, &rArray, &dArray, &tsArray, &jArray) + if err != nil { + t.Fatal(err) + } + if g, w := b.Value.GetBoolValue(), true; g != w { + t.Errorf("row value mismatch for bool\nGot: %v\nWant: %v", g, w) + } + if g, w := s.Value.GetStringValue(), "test"; g != w { + t.Errorf("row value mismatch for string\nGot: %v\nWant: %v", g, w) + } + if g, w := bt.Value.GetStringValue(), base64.RawURLEncoding.EncodeToString([]byte("testbytes")); !cmp.Equal(g, w) { + t.Errorf("row value mismatch for bytes\nGot: %v\nWant: %v", g, w) + } + if g, w := i.Value.GetStringValue(), "5"; g != w { + t.Errorf("row value mismatch for int64\nGot: %v\nWant: %v", g, w) + } + if g, w := float32(f32.Value.GetNumberValue()), float32(3.14); g != w { + t.Errorf("row value mismatch for float32\nGot: %v\nWant: %v", g, w) + } + if g, w := f.Value.GetNumberValue(), 3.14; g != w { + t.Errorf("row value mismatch for float64\nGot: %v\nWant: %v", g, w) + } + if g, w := r.Value.GetStringValue(), "6.626"; g != w { + t.Errorf("row value mismatch for numeric\nGot: %v\nWant: %v", g, w) + } + if g, w := d.Value.GetStringValue(), "2021-07-21"; g != w { + t.Errorf("row value mismatch for date\nGot: %v\nWant: %v", g, w) + } + if g, w := ts.Value.GetStringValue(), "2021-07-21T21:07:59.339911800Z"; g != w { + t.Errorf("row value mismatch for timestamp\nGot: %v\nWant: %v", g, w) + } + if g, w := j.Value.GetStringValue(), `{"key": "value", "other-key": ["value1", "value2"]}`; g != w { + t.Errorf("row value mismatch for json\n Got: %v\nWant: %v", g, w) + } + } + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + rows.Close() + } +} + func TestDmlInAutocommit(t *testing.T) { t.Parallel() diff --git a/examples/decode-options/main.go b/examples/decode-options/main.go new file mode 100644 index 00000000..d4713210 --- /dev/null +++ b/examples/decode-options/main.go @@ -0,0 +1,69 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "database/sql" + "fmt" + + "cloud.google.com/go/spanner" + spannerdriver "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/examples" +) + +// Example for getting the underlying protobuf objects from a query result +// instead of decoding the values into Go types. This can be used for +// advanced use cases where you want to have full control over how data +// is decoded, or where you want to skip the decode step for performance +// reasons. +func decodeOptions(projectId, instanceId, databaseId string) error { + ctx := context.Background() + db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, databaseId)) + if err != nil { + return fmt.Errorf("failed to open database connection: %v", err) + } + defer db.Close() + + // Pass an ExecOptions value with DecodeOption set to DecodeOptionProto + // as an argument to QueryContext to instruct the Spanner driver to skip + // decoding the data into Go types. + rows, err := db.QueryContext(ctx, + `SELECT JSON '{"key1": "value1", "key2": 2, "key3": ["value1", "value2"]}'`, + spannerdriver.ExecOptions{DecodeOption: spannerdriver.DecodeOptionProto}) + if err != nil { + return fmt.Errorf("failed to execute query: %v", err) + } + defer rows.Close() + + for rows.Next() { + // As we are using DecodeOptionProto, all values must be scanned + // into spanner.GenericColumnValue. + var value spanner.GenericColumnValue + + if err := rows.Scan(&value); err != nil { + return fmt.Errorf("failed to scan row values: %v", err) + } + fmt.Printf("Received value %v\n", value.Value.GetStringValue()) + } + if err := rows.Err(); err != nil { + return fmt.Errorf("failed to execute query: %v", err) + } + return nil +} + +func main() { + examples.RunSampleOnEmulator(decodeOptions) +} diff --git a/rows.go b/rows.go index d652caff..0ed3a796 100644 --- a/rows.go +++ b/rows.go @@ -28,9 +28,10 @@ import ( type rows struct { it rowIterator - colsOnce sync.Once - dirtyErr error - cols []string + colsOnce sync.Once + dirtyErr error + cols []string + decodeOption DecodeOption dirtyRow *spanner.Row } @@ -107,6 +108,10 @@ func (r *rows) Next(dest []driver.Value) error { if err := row.Column(i, &col); err != nil { return err } + if r.decodeOption == DecodeOptionProto { + dest[i] = col + continue + } switch col.Type.Code { case sppb.TypeCode_INT64: var v spanner.NullInt64 diff --git a/stmt.go b/stmt.go index f4e7881c..18872da0 100644 --- a/stmt.go +++ b/stmt.go @@ -24,9 +24,10 @@ import ( ) type stmt struct { - conn *conn - numArgs int - query string + conn *conn + numArgs int + query string + execOptions ExecOptions } func (s *stmt) Close() error { @@ -61,7 +62,19 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv } else { it = &readOnlyRowIterator{s.conn.client.Single().WithTimestampBound(s.conn.readOnlyStaleness).Query(ctx, ss)} } - return &rows{it: it}, nil + return &rows{it: it, decodeOption: s.execOptions.DecodeOption}, nil +} + +func (s *stmt) CheckNamedValue(value *driver.NamedValue) error { + if value == nil { + return nil + } + + if execOptions, ok := value.Value.(ExecOptions); ok { + s.execOptions = execOptions + return driver.ErrRemoveArgument + } + return nil } func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement, error) {