Skip to content

Commit

Permalink
Fix INSERT support (#5)
Browse files Browse the repository at this point in the history
INSERT statements work now. The PR also removes some warnings, adds
tests and comments.
A few libraries are updated to resolve security notices.

Testing Done: newly added integration test
  • Loading branch information
murfffi authored Dec 29, 2024
1 parent 1b8ebac commit 38e8f84
Show file tree
Hide file tree
Showing 12 changed files with 233 additions and 80 deletions.
25 changes: 23 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package impala
import (
"context"
"database/sql/driver"
"fmt"
"log"
"time"

Expand All @@ -19,7 +20,10 @@ type Conn struct {
}

// Ping impala server
// Implements driver.Pinger
func (c *Conn) Ping(ctx context.Context) error {
// TODO (Github #4) report ErrBadConn when appropriate

session, err := c.OpenSession(ctx)
if err != nil {
return err
Expand All @@ -45,19 +49,22 @@ func (c *Conn) CheckNamedValue(val *driver.NamedValue) error {
}

// Prepare returns prepared statement
// Implements driver.Conn
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}

// PrepareContext returns prepared statement
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
// Implements driver.ConnPrepareContext
func (c *Conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
return &Stmt{
conn: c,
stmt: template(query),
}, nil
}

// QueryContext executes a query that may return rows
// Implements driver.QueryerContext
func (c *Conn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) {
session, err := c.OpenSession(ctx)
if err != nil {
Expand All @@ -70,6 +77,7 @@ func (c *Conn) QueryContext(ctx context.Context, q string, args []driver.NamedVa
}

// ExecContext executes a query that doesn't return rows
// Implements driver.ExecerContext
func (c *Conn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) {
session, err := c.OpenSession(ctx)
if err != nil {
Expand All @@ -82,6 +90,7 @@ func (c *Conn) ExecContext(ctx context.Context, q string, args []driver.NamedVal
}

// Begin is not supported
// Implements driver.Conn
func (c *Conn) Begin() (driver.Tx, error) {
return nil, ErrNotSupported
}
Expand All @@ -100,6 +109,7 @@ func (c *Conn) OpenSession(ctx context.Context) (*hive.Session, error) {
}

// ResetSession closes hive session
// Implements driver.SessionResetter
func (c *Conn) ResetSession(ctx context.Context) error {
if c.session != nil {
if err := c.session.Close(ctx); err != nil {
Expand All @@ -111,7 +121,18 @@ func (c *Conn) ResetSession(ctx context.Context) error {
}

// Close connection
// Implements driver.Conn
func (c *Conn) Close() error {
c.log.Printf("close connection")
return c.t.Close()
if c.session != nil {
err := c.session.Close(context.Background())
if err != nil {
return fmt.Errorf("failed to close underlying session while closing connection: %w", err)
}
}

if err := c.t.Close(); err != nil {
return fmt.Errorf("failed to close underlying transport while closing connection: %w", err)
}
return nil
}
24 changes: 23 additions & 1 deletion connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestIntegration(t *testing.T) {
dsn := os.Getenv("IMPALA_DSN")
if dsn == "" {
ctx := context.Background()
t.Log("No IMPALA DSN environment variable set, starting Impala container ...")
t.Log("No IMPALA_DSN environment variable set, starting Impala container ...")
c := fi.NoError(Setup(ctx)).Require(t)
defer fi.NoErrorF(fi.Bind(c.Terminate, ctx), t)
dsn = GetDsn(ctx, t, c)
Expand All @@ -44,6 +44,9 @@ func TestIntegration(t *testing.T) {
t.Run("Metadata", func(t *testing.T) {
testMetadata(t, conn)
})
t.Run("Insert", func(t *testing.T) {
testInsert(t, conn)
})
}

func testPinger(t *testing.T, conn *sql.DB) {
Expand Down Expand Up @@ -93,6 +96,25 @@ func testMetadata(t *testing.T, conn *sql.DB) {
}))
}

func testInsert(t *testing.T, conn *sql.DB) {
var err error
_, err = conn.Exec("DROP TABLE IF EXISTS test")
require.NoError(t, err)
_, err = conn.Exec("CREATE TABLE if not exists test(a int)")
require.NoError(t, err)
insertRes, err := conn.Exec("INSERT INTO test (a) VALUES (1)")
require.NoError(t, err)
_, err = insertRes.RowsAffected()
require.Error(t, err) // not supported yet, see todo in statement.go/exec
selectRes, err := conn.Query("SELECT * FROM test WHERE a = 1 LIMIT 1")
require.NoError(t, err)
defer fi.NoErrorF(selectRes.Close, t)
require.True(t, selectRes.Next())
var val int
require.NoError(t, selectRes.Scan(&val))
require.Equal(t, val, 1)
}

func open(t *testing.T, dsn string) *sql.DB {
db, err := sql.Open("impala", dsn)
if err != nil {
Expand Down
22 changes: 18 additions & 4 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"strings"

"github.com/apache/thrift/lib/go/thrift"
"github.com/samber/lo"
"github.com/sclgo/impala-go/hive"
"github.com/sclgo/impala-go/sasl"
)
Expand Down Expand Up @@ -81,9 +82,9 @@ func parseURI(uri string) (*Options, error) {
opts.UseLDAP = true
}

tls, ok := query["tls"]
useTls, ok := query["tls"]
if ok {
v, err := strconv.ParseBool(tls[0])
v, err := strconv.ParseBool(useTls[0])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -127,6 +128,13 @@ func parseURI(uri string) (*Options, error) {
opts.QueryTimeout = qTimeout
}

logDest, ok := query["log"]
if ok {
if strings.ToLower(logDest[0]) == "stderr" {
opts.LogOut = os.Stderr
}
}

return &opts, nil
}

Expand All @@ -151,10 +159,13 @@ func NewConnector(opts *Options) driver.Connector {
return &connector{opts: opts}
}

func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
// Connect implements driver.Connector
func (c *connector) Connect(context.Context) (driver.Conn, error) {
// Strangely, TTransport.Open doesn't support context, so we don't use it here
return connect(c.opts)
}

// Driver implements driver.Connector
func (c *connector) Driver() driver.Driver {
return c.d
}
Expand Down Expand Up @@ -213,7 +224,10 @@ func connect(opts *Options) (*Conn, error) {
transport = thrift.NewTBufferedTransport(socket, opts.BufferSize)
}

protocol := thrift.NewTBinaryProtocol(transport, false, true)
protocol := thrift.NewTBinaryProtocolConf(transport, &thrift.TConfiguration{
TBinaryStrictRead: lo.ToPtr(false),
TBinaryStrictWrite: lo.ToPtr(true),
})

if err := transport.Open(); err != nil {
return nil, err
Expand Down
26 changes: 12 additions & 14 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,31 @@ toolchain go1.23.1

require (
github.com/apache/thrift v0.19.0
github.com/google/uuid v1.6.0
github.com/samber/lo v1.47.0
github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.32.0
github.com/stretchr/testify v1.10.0
github.com/testcontainers/testcontainers-go v0.34.0
)

require (
dario.cat/mergo v1.0.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.11.5 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/containerd/containerd v1.7.18 // indirect
github.com/containerd/errdefs v0.1.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v27.0.3+incompatible // indirect
github.com/docker/docker v27.1.1+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/klauspost/compress v1.17.4 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
Expand All @@ -58,11 +56,11 @@ require (
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
golang.org/x/crypto v0.22.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/text v0.16.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b // indirect
google.golang.org/grpc v1.59.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading

0 comments on commit 38e8f84

Please sign in to comment.