Skip to content

Commit

Permalink
feat: expose underlying *spanner.Client
Browse files Browse the repository at this point in the history
Sometimes the exposed features of go-sql-spanner are not sufficient to
get the necessary response. This adds an escape hatch to access the
underlying client.

Fixes googleapis#312
  • Loading branch information
egonelbre committed Nov 6, 2024
1 parent c39f57f commit a0960b1
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 2 deletions.
18 changes: 16 additions & 2 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,10 @@ type SpannerConn interface {
// was executed on the connection, or an error if the connection has not executed a read/write transaction
// that committed successfully. The timestamp is in the local timezone.
CommitTimestamp() (commitTimestamp time.Time, err error)

// UnderlyingClient returns the underlying client to the database.
// This should not be used together with transactions and batching.
UnderlyingClient() (client *spanner.Client, err error)
}

type conn struct {
Expand Down Expand Up @@ -542,6 +546,16 @@ const (
PartitionedNonAtomic
)

func (c *conn) UnderlyingClient() (*spanner.Client, error) {
if c.inTransaction() {
return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "cannot access *spanner.Client when in a transaction"))
}
if c.batch != nil {
return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "cannot access *spanner.Client with an active batch"))
}
return c.client, nil
}

func (c *conn) CommitTimestamp() (time.Time, error) {
if c.commitTs == nil {
return time.Time{}, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "this connection has not executed a read/write transaction that committed successfully"))
Expand Down Expand Up @@ -1150,8 +1164,8 @@ var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// to those types to mean nil/NULL, just like the Go database/sql package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
return nil, nil
}
return vr.Value()
Expand Down
71 changes: 71 additions & 0 deletions examples/underlying-client/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"context"
"database/sql"
"fmt"

"cloud.google.com/go/spanner"
_ "github.com/googleapis/go-sql-spanner"
spannerdriver "github.com/googleapis/go-sql-spanner"
"github.com/googleapis/go-sql-spanner/examples"
)

// Example of using the underlying *spanner.Client.
func underlyingClient(projectId, instanceId, databaseId string) error {
ctx := context.Background()
db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, databaseId))
if err != nil {
return fmt.Errorf("failed to open database connection: %v\n", err)
}
defer db.Close()

conn, err := db.Conn(ctx)
if err != nil {
return err
}
defer conn.Close()

if err := conn.Raw(func(driverConn any) error {
spannerConn, ok := driverConn.(spannerdriver.SpannerConn)
if !ok {
return fmt.Errorf("unexpected driver connection %v, expected SpannerConn", driverConn)
}
client, err := spannerConn.UnderlyingClient()
if err != nil {
return fmt.Errorf("unable to access underlying client: %w", err)
}

row := client.Single().Query(ctx, spanner.Statement{SQL: "SELECT 1"})
return row.Do(func(r *spanner.Row) error {
var value int64
err := r.Columns(&value)
if err != nil {
return fmt.Errorf("failed to read column: %w", err)
}
fmt.Println(value)
return nil
})
}); err != nil {
return err
}
return nil
}

func main() {
examples.RunSampleOnEmulator(underlyingClient)
}

0 comments on commit a0960b1

Please sign in to comment.