Skip to content

Commit

Permalink
go/network: dialShard should use background context
Browse files Browse the repository at this point in the history
The context passed to dialShard is _only_ for dialing, and is often
cancelled shortly after the call.

Use it only for listing shards, but "dial" the proxy RPC using context.Background(),
and require that the caller properly closes the resulting connection.
  • Loading branch information
jgraettinger committed Sep 23, 2024
1 parent 52ace65 commit 91ae753
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
14 changes: 7 additions & 7 deletions go/network/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ type Frontend struct {
// frontendConn is the state of a connection initiated
// by a user into the Frontend.
type frontendConn struct {
id uintptr
ctx context.Context
id uintptr

// Raw and TLS-wrapped connections to the user.
raw net.Conn
Expand Down Expand Up @@ -153,7 +152,6 @@ func (p *Frontend) Serve(ctx context.Context) (_err error) {
func (p *Frontend) serveConn(ctx context.Context, raw net.Conn) {
var conn = &frontendConn{
id: reflect.ValueOf(raw).Pointer(),
ctx: ctx,
raw: raw,
tls: tls.Server(raw, p.tlsConfig),
}
Expand All @@ -164,15 +162,18 @@ func (p *Frontend) serveConn(ctx context.Context, raw net.Conn) {
p.handshakeMu.Unlock()

// The TLS handshake machinery will next call into getTLSConfigForClient().
var err = conn.tls.HandshakeContext(conn.ctx)
var err = conn.tls.HandshakeContext(ctx)

// Clear `conn` from the map of current handshakes.
p.handshakeMu.Lock()
delete(p.handshake, conn.id)
p.handshakeMu.Unlock()

if err != nil {
handshakeCounter.WithLabelValues(err.Error()).Inc() // `err` is low-variance.
if conn.dialed != nil {
_ = conn.dialed.Close() // Handshake failed after we dialed the shard.
}
handshakeCounter.WithLabelValues("ErrHandshake").Inc()
p.serveConnErr(conn.raw, 421, "This service may only be accessed using TLS, such as through an https:// URL.\n")
return
}
Expand Down Expand Up @@ -238,7 +239,7 @@ func (p *Frontend) getTLSConfigForClient(hello *tls.ClientHelloInfo) (*tls.Confi
// think it has a good connection.
var addr = conn.raw.RemoteAddr().String()
conn.dialed, conn.dialErr = dialShard(
conn.ctx, p.networkClient, p.shardClient, conn.parsed, conn.resolved, addr)
hello.Context(), p.networkClient, p.shardClient, conn.parsed, conn.resolved, addr)
}

var nextProtos []string
Expand Down Expand Up @@ -374,7 +375,6 @@ func (p *Frontend) serveConnHTTP(user *frontendConn) {
// MaxConcurrentStreams is an important setting left as the default (100).
IdleTimeout: time.Minute,
}).ServeConn(user.tls, &http2.ServeConnOpts{
Context: user.ctx,
Handler: http.HandlerFunc(handle),
})

Expand Down
20 changes: 13 additions & 7 deletions go/network/proxy_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ func dialShard(
shardStartedCounter.WithLabelValues(labels...).Inc()

var fetched, err = listShards(ctx, shardClient, parsed, resolved.shardIDPrefix)
if err != nil {
if err == context.Canceled {
shardHandledCounter.WithLabelValues(append(labels, "ListCancelled")...).Inc()
return nil, err
} else if err != nil {
shardHandledCounter.WithLabelValues(append(labels, "ErrList")...).Inc()
return nil, fmt.Errorf("failed to list matching task shards: %w", err)
}
Expand Down Expand Up @@ -68,14 +71,17 @@ func dialShard(
var picked = fetched[primary]

rpc, err := networkClient.Proxy(
// Build a context that routes to the shard primary and encodes `claims`.
// We do not wrap `ctx` because that's only the context for dialing,
// and not the context of the long-lived connection that results.
pb.WithDispatchRoute(
pb.WithClaims(ctx, claims),
pb.WithClaims(context.Background(), claims),
picked.Route,
picked.Route.Members[picked.Route.Primary],
),
)
if err != nil {
shardHandledCounter.WithLabelValues(append(labels, "ErrProxy")...).Inc()
shardHandledCounter.WithLabelValues(append(labels, "ErrCallProxy")...).Inc()
return nil, fmt.Errorf("failed to start network proxy RPC to task shard: %w", err)
}

Expand All @@ -90,7 +96,7 @@ func dialShard(

opened, err := rpc.Recv()
if err != nil {
err = fmt.Errorf("failed to read opened response from task shard: %w", err)
err = fmt.Errorf("failed to read opened response from task shard: %w", pf.UnwrapGRPCError(err))
} else if opened.OpenResponse == nil {
err = fmt.Errorf("task shard proxy RPC is missing expected OpenResponse")
} else if status := opened.OpenResponse.Status; status != pf.TaskNetworkProxyResponse_OK {
Expand Down Expand Up @@ -127,7 +133,7 @@ func dialShard(
// Write to the shard proxy client. MUST NOT be called concurrently with Close.
func (pc *proxyClient) Write(b []byte) (n int, err error) {
if err = pc.rpc.Send(&pf.TaskNetworkProxyRequest{Data: b}); err != nil {
return 0, err
return 0, err // This is io.EOF if the RPC is reset.
}
pc.nWrite.Add(float64(len(b)))
return len(b), nil
Expand All @@ -148,7 +154,7 @@ func (pc *proxyClient) Read(b []byte) (n int, err error) {
} else {
shardHandledCounter.WithLabelValues(append(pc.labels, "ErrRead")...).Inc()
}
return 0, err
return 0, pf.UnwrapGRPCError(err)
} else {
pc.buf = rx.Data
pc.rxCh <- struct{}{} // Yield token.
Expand Down Expand Up @@ -182,7 +188,7 @@ func (pc *proxyClient) Close() error {
return nil
} else if err != nil {
shardHandledCounter.WithLabelValues(append(pc.labels, "ErrClose")...).Inc()
return err
return pf.UnwrapGRPCError(err)
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion go/network/sni.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/estuary/flow/go/labels"
pf "github.com/estuary/flow/go/protocols/flow"
pb "go.gazette.dev/core/broker/protocol"
pc "go.gazette.dev/core/consumer/protocol"
)
Expand Down Expand Up @@ -108,7 +109,7 @@ func listShards(ctx context.Context, shards pc.ShardClient, parsed parsedSNI, sh
err = errors.New(resp.Status.String())
}
if err != nil {
return nil, err
return nil, pf.UnwrapGRPCError(err)
}

return resp.Shards, nil
Expand Down

0 comments on commit 91ae753

Please sign in to comment.